Skip to content

Commit 25caeb0

Browse files
author
Felipe Mello
committed
fix wandb hang
1 parent: 60e6382 · commit: 25caeb0

File tree

4 files changed

+33
-31
lines changed

4 files changed

+33
-31
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -322,6 +322,7 @@ async def main(cfg: DictConfig):
322322
)
323323
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
324324
mlogger = await get_or_create_metric_logger(process_name="Controller")
325+
await mlogger.init_backends.call_one(metric_logging_cfg)
325326
await ts.initialize(strategy=ts.ControllerStorageVolumes())
326327

327328
# ---- Setup services ---- #
@@ -350,11 +351,6 @@ async def main(cfg: DictConfig):
350351
),
351352
)
352353

353-
# Call after services are initialized
354-
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
355-
# probably wandb requires primary runs to finish before shared runs can be initialized
356-
await mlogger.init_backends.call_one(metric_logging_cfg)
357-
358354
print("All services initialized successfully!")
359355

360356
# ---- Core RL loops ---- #

apps/toy_rl/toy_metrics/main.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import asyncio
88

99
import logging
10-
import time
10+
from datetime import datetime
1111

1212
from forge.controller.actor import ForgeActor
1313
from forge.controller.provisioner import shutdown
@@ -17,8 +17,13 @@
1717

1818
from monarch.actor import current_rank, endpoint
1919

20-
logging.basicConfig(level=logging.DEBUG)
21-
logging.getLogger("forge.observability.metrics").setLevel(logging.DEBUG)
20+
logging.basicConfig(
21+
level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
22+
)
23+
logging.getLogger("forge.observability.metrics").setLevel(logging.INFO)
24+
logging.getLogger("forge.observability.metric_actors").setLevel(logging.INFO)
25+
# Reduce wandb logging noise
26+
logging.getLogger("wandb").setLevel(logging.WARNING)
2227

2328

2429
class TrainActor(ForgeActor):
@@ -79,8 +84,7 @@ async def generate_step(self, step: int, substep: int):
7984

8085
# Main
8186
async def main():
82-
"""Example demonstrating distributed metric logging with different backends."""
83-
group = f"grpo_exp_{int(time.time())}"
87+
group = "time" + str(int(datetime.now().timestamp()))
8488

8589
# Config format: {backend_name: backend_config_dict}
8690
config = {
@@ -89,22 +93,18 @@ async def main():
8993
"project": "toy_metrics",
9094
"group": group,
9195
"logging_mode": "per_rank_no_reduce",
92-
"per_rank_share_run": False,
96+
"per_rank_share_run": True,
9397
},
9498
}
9599

96100
service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
97101
mlogger = await get_or_create_metric_logger(process_name="Controller")
102+
await mlogger.init_backends.call_one(config)
98103

99-
# Spawn services first (triggers registrations via provisioner hook)
104+
# Spawn services (will register fetchers)
100105
trainer = await TrainActor.options(**service_config).as_service()
101106
generator = await GeneratorActor.options(**service_config).as_service()
102107

103-
# Call after services are initialized
104-
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
105-
# probably wandb requires primary runs to finish before shared runs can be initialized
106-
await mlogger.init_backends.call_one(config)
107-
108108
for i in range(3):
109109
print(f"\n=== Global Step {i} ===")
110110
record_metric("main/global_step", 1, Reduce.MEAN)

apps/vllm/main.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,8 @@
2727

2828
async def run(cfg: DictConfig):
2929
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
30-
mlogger = await get_or_create_metric_logger(actor_name="Controller")
30+
mlogger = await get_or_create_metric_logger(process_name="Controller")
31+
await mlogger.init_backends.call_one(metric_logging_cfg)
3132

3233
if (prompt := cfg.get("prompt")) is None:
3334
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
@@ -36,11 +37,6 @@ async def run(cfg: DictConfig):
3637
print("Spawning service...")
3738
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
3839

39-
# Call after services are initialized
40-
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
41-
# probably wandb requires primary runs to finish before shared runs can be initialized
42-
await mlogger.init_backends.call_one(metric_logging_cfg)
43-
4440
import time
4541

4642
print("Requesting generation...")

src/forge/observability/metrics.py

Lines changed: 18 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -568,9 +568,6 @@ async def flush(
568568
states[key] = acc.get_state()
569569
acc.reset()
570570

571-
# Update step (used by NO_REDUCE backends in push)
572-
self.step = step
573-
574571
# Log to PER_RANK_REDUCE backends only (NO_REDUCE already logged in push)
575572
if self.per_rank_reduce_backends:
576573
metrics_for_backends = reduce_metrics_states([states])
@@ -579,6 +576,9 @@ async def flush(
579576
for backend in self.per_rank_reduce_backends:
580577
await backend.log_batch(metrics_for_backends, step)
581578

579+
# Update step (used by NO_REDUCE backends in push)
580+
self.step = step + 1
581+
582582
return states if return_state else {}
583583

584584
async def shutdown(self):
@@ -768,22 +768,32 @@ async def _init_shared_global(self):
768768
settings = wandb.Settings(
769769
mode="shared", x_primary=True, x_label="controller_primary"
770770
)
771-
self.run = wandb.init(project=self.project, group=self.group, settings=settings)
771+
772+
self.run = wandb.init(
773+
project=self.project,
774+
group=self.group,
775+
settings=settings,
776+
)
772777

773778
async def _init_shared_local(self, primary_metadata: Dict[str, Any]):
774779
import wandb
780+
from wandb.sdk.lib.service import service_token
775781

776782
shared_id = primary_metadata.get("shared_run_id")
777783
if shared_id is None:
778784
raise ValueError(
779785
f"Shared ID required but not provided for {self.name} backend init"
780786
)
787+
788+
# Clear any stale service tokens that might be pointing to dead processes
789+
# In multiprocessing environments, WandB service tokens can become stale and point
790+
# to dead service processes. This causes wandb.init() to hang indefinitely trying
791+
# to connect to non-existent services. Clearing forces fresh service connection.
792+
service_token.clear_service_in_env()
793+
781794
settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name)
782795
self.run = wandb.init(
783-
id=shared_id,
784-
project=self.project,
785-
group=self.group,
786-
settings=settings,
796+
id=shared_id, project=self.project, group=self.group, settings=settings
787797
)
788798

789799
async def log_batch(

0 commit comments

Comments
 (0)