6 changes: 4 additions & 2 deletions apps/grpo/main.py
@@ -334,8 +334,7 @@ async def main(cfg: DictConfig):

# initialize before spawning services
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

# ---- Setup services ---- #
await ts.initialize(strategy=ts.ControllerStorageVolumes())
@@ -363,6 +362,9 @@ async def main(cfg: DictConfig):
),
)

# Call after services are initialized
Contributor:

Would you maybe explain in the comment why `init_backends` should be called after services are initialized?

@felipemello1 (Contributor, Author), Oct 6, 2025:

Calling it before works for every mode except when `per_rank_share_run=True`; then it hangs. wandb says the feature is experimental, and I didn't investigate more deeply to see whether I need to wait for something to finish. But I agree, I will add a note! Edit: done

Contributor:

Can we debug this further instead of checking in this workaround?

await mlogger.init_backends.call_one(metric_logging_cfg)

print("All services initialized successfully!")

# ---- Core RL loops ---- #
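The net effect of this hunk is an ordering change: the controller-side metric logger is still created before any service is spawned (now with an explicit `actor_name="Controller"`), but `init_backends` runs only after all services are up, because the experimental `per_rank_share_run=True` wandb mode hangs otherwise (per the discussion above). A condensed sketch of that ordering, assuming the same imports as apps/grpo/main.py; the service-spawning step is only a comment here since it differs per app:

# Condensed sketch of the initialization order adopted in this PR.
# Only the mlogger calls are taken verbatim from the diff.
async def main(cfg):
    # 1) Create the metric logger actor before any service is spawned.
    mlogger = await get_or_create_metric_logger(actor_name="Controller")

    # 2) Spawn services here; per the toy_metrics example, spawning triggers
    #    per-rank logger registrations via the provisioner hook.

    # 3) Initialize backends only after all services exist. Per the PR
    #    discussion, calling this earlier hangs when wandb's
    #    per_rank_share_run=True mode is enabled.
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    await mlogger.init_backends.call_one(metric_logging_cfg)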
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_1_7b.yaml
@@ -14,13 +14,15 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas


# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
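All three GRPO YAML configs receive the same observability update: the boolean `reduce_across_ranks` is replaced by a `logging_mode` string taking `global_reduce`, `per_rank_reduce`, or `per_rank_no_reduce`, and the wandb backend gains a `per_rank_share_run` flag (successor to the old `share_run_id`). The sketch below shows a hypothetical per-rank variant in the dict form that `init_backends` receives, mirroring apps/toy_rl/toy_metrics/main.py; the inline comments on mode semantics are inferences from the key names and the old comments, not statements from this PR:

# Hypothetical per-rank variant of the metric_logging config, written as the
# dict passed to init_backends (comments on semantics are assumptions).
per_rank_metric_logging = {
    "console": {
        # presumably: each rank reduces its own metrics and logs them locally
        "logging_mode": "per_rank_reduce",
    },
    "wandb": {
        "project": "grpo-training",
        "group": "grpo_exp_example",  # hypothetical group name
        # presumably: every rank streams unreduced values to wandb
        "logging_mode": "per_rank_no_reduce",
        # False: one wandb run per rank; True: experimental shared run, which is
        # why init_backends must run after services are spawned (see main.py diff)
        "per_rank_share_run": False,
    },
}

The checked-in YAMLs keep `global_reduce` for both backends, which appears to preserve the previous `reduce_across_ranks: True` behavior.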
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_32b.yaml
@@ -14,13 +14,15 @@ off_by_n: 1 # Off by one by default
rollout_threads: 1 # Recommended to set equal to policy.num_replicas

# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_8b.yaml
@@ -10,13 +10,15 @@ model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default

# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
19 changes: 11 additions & 8 deletions apps/toy_rl/toy_metrics/main.py
@@ -18,6 +18,7 @@
from monarch.actor import current_rank, endpoint

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("forge.observability.metrics").setLevel(logging.DEBUG)


class TrainActor(ForgeActor):
@@ -82,31 +83,33 @@ async def main():
group = f"grpo_exp_{int(time.time())}"

# Config format: {backend_name: backend_config_dict}
# Each backend can specify reduce_across_ranks to control distributed logging behavior
config = {
"console": {"reduce_across_ranks": True},
"console": {"logging_mode": "per_rank_reduce"},
"wandb": {
"project": "my_project",
"project": "toy_metrics",
"group": group,
"reduce_across_ranks": False,
# Only useful if NOT reduce_across_ranks.
"share_run_id": False, # Share run ID across ranks -- Not recommended.
"logging_mode": "per_rank_no_reduce",
"per_rank_share_run": False,
},
}

service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(config)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

# Spawn services first (triggers registrations via provisioner hook)
trainer = await TrainActor.options(**service_config).as_service()
generator = await GeneratorActor.options(**service_config).as_service()

# Initialize after spawning services
await mlogger.init_backends.call_one(config)

for i in range(3):
print(f"\n=== Global Step {i} ===")
record_metric("main/global_step", 1, Reduce.MEAN)
await trainer.train_step.fanout(i)
for sub in range(3):
await generator.generate_step.fanout(i, sub)
await asyncio.sleep(0.1)
await mlogger.flush.call_one(i)

# shutdown
6 changes: 4 additions & 2 deletions apps/vllm/main.py
@@ -27,8 +27,7 @@

async def run(cfg: DictConfig):
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
@@ -37,6 +36,9 @@ async def run(cfg: DictConfig):
print("Spawning service...")
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

# initialize after spawning services
await mlogger.init_backends.call_one(metric_logging_cfg)

import time

print("Requesting generation...")