8 changes: 6 additions & 2 deletions apps/grpo/main.py
@@ -334,8 +334,7 @@ async def main(cfg: DictConfig):

# initialize before spawning services
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

# ---- Setup services ---- #
await ts.initialize(strategy=ts.ControllerStorageVolumes())
@@ -363,6 +362,11 @@ ),
),
)

# Call after services are initialized
Contributor:
Would you maybe explain in the comment why init_backends should be called after services are initialized?

Contributor Author (@felipemello1, Oct 6, 2025):
Calling before works for every mode except when per_rank_share_run=True; then it hangs. wandb says it's experimental, and I didn't investigate more deeply to see whether I need to wait for something to finish. But I agree, I will add a note! Edit: done

Contributor:
Can we debug this further instead of checking in this workaround?

# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
# probably wandb requires primary runs to finish before shared runs can be initialized
await mlogger.init_backends.call_one(metric_logging_cfg)

print("All services initialized successfully!")

# ---- Core RL loops ---- #
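For reference, a minimal sketch of the ordering this diff adopts: create the controller-side logger, spawn services (which registers per-rank loggers via the provisioner hook), and only then initialize the backends. The get_or_create_metric_logger and init_backends calls are taken from the diff; the import path, the spawn_services callable, and the wrapper function are illustrative assumptions, not the full main.py.

from omegaconf import DictConfig

# Assumed import path; the diff does not show the import block.
from forge.observability.metric_actors import get_or_create_metric_logger


async def setup_observability(cfg: DictConfig, spawn_services):
    """Hypothetical helper mirroring the order used in apps/grpo/main.py."""
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})

    # 1) Create the controller-side logger before spawning services, so the
    #    provisioner hook can register per-rank loggers as each service comes up.
    mlogger = await get_or_create_metric_logger(actor_name="Controller")

    # 2) Spawn services (this is where the per-rank registrations happen).
    services = await spawn_services()

    # 3) Initialize backends last; per the TODO above, calling this earlier
    #    hangs when per_rank_share_run=True.
    await mlogger.init_backends.call_one(metric_logging_cfg)
    return mlogger, services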
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_1_7b.yaml
@@ -14,13 +14,15 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas


# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
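For context, the three modes named in the comment above feed the same {backend_name: backend_config} dict that init_backends receives (the format shown in apps/toy_rl/toy_metrics/main.py). Below is a sketch with placeholder project/group values; the per-mode descriptions are inferred from the mode names and from the keys they replace (reduce_across_ranks, share_run_id), so check the observability code for exact semantics.

# Hypothetical config; key names follow the YAML in this diff.
metric_logging_cfg = {
    # "global_reduce": one reduced value for the whole job (replaces reduce_across_ranks: True)
    "console": {"logging_mode": "global_reduce"},
    "wandb": {
        "project": "grpo-training",        # placeholder
        "group": "grpo_exp_demo",          # placeholder
        # "per_rank_reduce" / "per_rank_no_reduce": each rank logs its own reduced or raw values
        "logging_mode": "per_rank_no_reduce",
        "per_rank_share_run": False,       # True would share one wandb run across ranks
    },
}

# Passed to the controller-side logger after services are spawned:
# await mlogger.init_backends.call_one(metric_logging_cfg)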
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_32b.yaml
@@ -14,13 +14,15 @@ off_by_n: 1 # Off by one by default
rollout_threads: 1 # Recommended to set equal to policy.num_replicas

# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_8b.yaml
@@ -10,13 +10,15 @@ model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default

# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
21 changes: 13 additions & 8 deletions apps/toy_rl/toy_metrics/main.py
@@ -18,6 +18,7 @@
from monarch.actor import current_rank, endpoint

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("forge.observability.metrics").setLevel(logging.DEBUG)


class TrainActor(ForgeActor):
@@ -82,31 +83,35 @@ async def main():
group = f"grpo_exp_{int(time.time())}"

# Config format: {backend_name: backend_config_dict}
# Each backend can specify reduce_across_ranks to control distributed logging behavior
config = {
"console": {"reduce_across_ranks": True},
"console": {"logging_mode": "per_rank_reduce"},
"wandb": {
"project": "my_project",
"project": "toy_metrics",
"group": group,
"reduce_across_ranks": False,
# Only useful if NOT reduce_across_ranks.
"share_run_id": False, # Share run ID across ranks -- Not recommended.
"logging_mode": "per_rank_no_reduce",
"per_rank_share_run": False,
},
}

service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(config)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

# Spawn services first (triggers registrations via provisioner hook)
trainer = await TrainActor.options(**service_config).as_service()
generator = await GeneratorActor.options(**service_config).as_service()

# Call after services are initialized
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
# probably wandb requires primary runs to finish before shared runs can be initialized
await mlogger.init_backends.call_one(config)

for i in range(3):
print(f"\n=== Global Step {i} ===")
record_metric("main/global_step", 1, Reduce.MEAN)
await trainer.train_step.fanout(i)
for sub in range(3):
await generator.generate_step.fanout(i, sub)
await asyncio.sleep(0.1)
await mlogger.flush.call_one(i)

# shutdown
8 changes: 6 additions & 2 deletions apps/vllm/main.py
@@ -27,8 +27,7 @@

async def run(cfg: DictConfig):
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
@@ -37,6 +36,11 @@ async def run(cfg: DictConfig):
print("Spawning service...")
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

# Call after services are initialized
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
# probably wandb requires primary runs to finish before shared runs can be initialized
await mlogger.init_backends.call_one(metric_logging_cfg)

import time

print("Requesting generation...")