Skip to content

Commit 8711f02

Browse files
felipemello1 (Felipe Mello)
and co-authors authored
fix - Metric logging work with new monarch API (meta-pytorch#451)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 27ef76d commit 8711f02

File tree

13 files changed

+415
-125
lines changed

13 files changed

+415
-125
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
305305
provisioner = await init_provisioner()
306306

307307
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
308-
mlogger = await get_or_create_metric_logger()
308+
mlogger = await get_or_create_metric_logger(process_name="Controller")
309309
await mlogger.init_backends.call_one(metric_logging_cfg)
310310

311311
# ---- Setup services ---- #

src/forge/controller/provisioner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -310,11 +310,12 @@ def bootstrap(env: dict[str, str]):
310310

311311
self._proc_host_map[procs] = host_mesh
312312

313-
# Spawn local fetcher actor on each process and register with global logger
313+
# Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
314+
# When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
314315
if not FORGE_DISABLE_METRICS.get_value():
315316
from forge.observability.metric_actors import get_or_create_metric_logger
316317

317-
_ = await get_or_create_metric_logger(procs)
318+
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
318319
return procs
319320

320321
async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
@@ -333,14 +334,14 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
333334
)
334335
return
335336
async with self._lock:
336-
# Deregister local logger from global logger
337-
if hasattr(proc_mesh, "_local_fetcher"):
337+
# Deregister LocalFetcherActor from GlobalLoggingActor
338+
if hasattr(proc_mesh, "_local_fetcher") and hasattr(proc_mesh, "_uid"):
338339
from forge.observability.metric_actors import (
339340
get_or_create_metric_logger,
340341
)
341342

342343
global_logger = await get_or_create_metric_logger(proc_mesh)
343-
await global_logger.deregister_fetcher.call_one(proc_mesh)
344+
await global_logger.deregister_fetcher.call_one(proc_mesh._uid)
344345

345346
if hasattr(proc_mesh, "_gpu_ids"):
346347
gpu_manager = self._host_gpu_map[proc_mesh._host._host_id]

src/forge/observability/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
from .metrics import (
1313
BackendRole,
1414
ConsoleBackend,
15-
get_actor_name_with_rank,
16-
get_logger_backend_class,
1715
LoggerBackend,
1816
MaxAccumulator,
1917
MeanAccumulator,
@@ -29,12 +27,12 @@
2927
WandbBackend,
3028
)
3129
from .perf_tracker import trace, Tracer
30+
from .utils import get_proc_name_with_rank
3231

3332
__all__ = [
3433
# Main API functions
3534
"record_metric",
3635
"reduce_metrics_states",
37-
"get_actor_name_with_rank",
3836
"get_logger_backend_class",
3937
"get_or_create_metric_logger",
4038
# Performance tracking
@@ -45,6 +43,8 @@
4543
"BackendRole",
4644
# Enums
4745
"Reduce",
46+
# Utility functions
47+
"get_proc_name_with_rank",
4848
# Actor classes
4949
"GlobalLoggingActor",
5050
"LocalFetcherActor",

src/forge/observability/metric_actors.py

Lines changed: 81 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,17 @@
66

77
import asyncio
88
import logging
9+
import uuid
910
from typing import Any, Union
1011

11-
from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
12+
from monarch.actor import (
13+
Actor,
14+
context,
15+
endpoint,
16+
get_or_spawn_controller,
17+
ProcMesh,
18+
this_proc,
19+
)
1220

1321
from forge.env import FORGE_DISABLE_METRICS
1422
from forge.observability.metrics import (
@@ -27,36 +35,35 @@
2735

2836
async def get_or_create_metric_logger(
2937
proc_mesh: ProcMesh | None = None,
38+
process_name: str | None = None,
3039
) -> "GlobalLoggingActor":
31-
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
32-
if not already initialized, registers it with the GlobalLoggingActor and returns the
33-
GlobalLoggingActor instance.
40+
"""Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized),
41+
registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor.
3442
35-
There are primarily two ways to use this function:
36-
1. In the main process, call `get_or_create_metric_logger()` to get the global logger.
37-
2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the
38-
local fetcher with the global logger.
43+
Usage:
44+
1. Main process: call `get_or_create_metric_logger()` to get the global logger
45+
2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the
46+
map(proc_mesh,local fetcher) with the global logger, so it knows to broadcast to all ranks.
3947
4048
Args:
41-
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
42-
uses `monarch.actor.this_proc()`.
49+
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`.
50+
process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None.
4351
4452
Returns:
4553
GlobalLoggingActor: The global logging controller.
4654
4755
Raises:
48-
ValueError: If the logging state is inconsistent, i.e. the fetcher is already
49-
registered, but only in the process or the global logger.
56+
ValueError: If the logging state is inconsistent.
5057
5158
Example:
5259
from forge.observability.metric_actors import get_or_create_metric_logger
5360
from forge.observability.metrics import record_metric
5461
5562
# Main process setup
56-
mlogger = await get_or_create_metric_logger()
63+
mlogger = await get_or_create_metric_logger(process_name="Controller")
5764
5865
# Initialize logging backends
59-
await mlogger.init_backends({
66+
await mlogger.init_backends.call_one({
6067
"console": {"reduce_across_ranks": True},
6168
"wandb": {"project": "my_project", "reduce_across_ranks": False}
6269
})
@@ -66,12 +73,12 @@ async def get_or_create_metric_logger(
6673
6774
# Training loop
6875
for step in range(max_steps):
69-
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
76+
record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
7077
# ... training code with record_metric() calls ...
71-
await mlogger.flush(step) # Log metrics for this step
78+
await mlogger.flush.call_one(step) # Log metrics for this step
7279
7380
# Shutdown
74-
await mlogger.shutdown()
81+
await mlogger.shutdown.call_one()
7582
"""
7683
# Get or create the singleton global logger
7784
global _global_logger
@@ -85,9 +92,15 @@ async def get_or_create_metric_logger(
8592
# Determine process context
8693
proc = proc_mesh if proc_mesh is not None else this_proc()
8794

95+
# Auto-detect process_name from proc mesh if not provided
96+
if process_name is None:
97+
ctx = context()
98+
process_name = ctx.actor_instance.actor_id.actor_name
99+
88100
# Check current state for consistency
89101
proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
90-
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
102+
proc_id = proc._uid if proc_has_local_fetcher else None
103+
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id)
91104

92105
# Consistency check: both should be in sync
93106
if proc_has_local_fetcher != global_logger_has_local_fetcher:
@@ -102,24 +115,32 @@ async def get_or_create_metric_logger(
102115
# Setup local_fetcher_actor if needed (unless disabled by environment flag)
103116
if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
104117
local_fetcher_actor = proc.spawn(
105-
"local_fetcher_actor", LocalFetcherActor, global_logger
118+
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name
106119
)
107-
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
120+
# Generate a unique ID to map procmesh to fetcher
121+
proc._uid = str(uuid.uuid4())
108122
proc._local_fetcher = local_fetcher_actor # pyre-ignore
109123

124+
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid)
125+
110126
return global_logger
111127

112128

113129
class LocalFetcherActor(Actor):
114-
"""Thin per-process actor used to trigger MetricCollector singleton
115-
operations without direct access. It is what GlobalLoggingActor
116-
uses to broadcast inits/flushes across ranks.
130+
"""Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh
131+
and accesses each rank's local MetricCollector.
117132
118-
GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
133+
Flow:
134+
GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
119135
"""
120136

121-
def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
137+
def __init__(
138+
self,
139+
global_logger: Union["GlobalLoggingActor", None] = None,
140+
process_name: str | None = None,
141+
) -> None:
122142
self.global_logger = global_logger
143+
self.process_name = process_name
123144
_is_initialized = False
124145

125146
@endpoint
@@ -146,10 +167,22 @@ async def init_backends(
146167
self,
147168
metadata_per_primary_backend: dict[str, dict[str, Any]],
148169
config: dict[str, Any],
170+
global_step: int = 0,
149171
) -> None:
150-
"""Init local (per-rank) logger backends and MetricCollector."""
172+
"""Init per-rank logger backends and MetricCollector.
173+
174+
Args:
175+
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
176+
config (dict[str, Any]): Backend configurations with logging modes and settings.
177+
global_step (int): Initial step for metrics.
178+
"""
151179
collector = MetricCollector()
152-
await collector.init_backends(metadata_per_primary_backend, config)
180+
await collector.init_backends(
181+
metadata_per_primary_backend,
182+
config,
183+
global_step,
184+
process_name=self.process_name,
185+
)
153186

154187
@endpoint
155188
async def shutdown(self) -> None:
@@ -158,22 +191,17 @@ async def shutdown(self) -> None:
158191

159192

160193
class GlobalLoggingActor(Actor):
161-
"""Coordinates metric logging across all ranks for every training step.
194+
"""Coordinates metric logging across all ProcMeshes and their ranks.
162195
163196
Supports multiple logging backends (e.g., WandB, TensorBoard, etc.),
164-
for per-rank and/or global reduction logging modes.
197+
with per-rank and/or global reduction logging modes.
165198
166199
If a backend config has flag `reduce_across_ranks=False`, an instance of the backend
167200
is initialized per-rank, otherwise it is done once globally.
168201
169-
This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor
170-
is automatically spawned per-rank in `forge.controller.provisioner.py` and registered
171-
with this actor. The LocalFetcherActor is responsible for instantiating
172-
the per-rank MetricCollector.
173202
174-
In summary, the flow is:
175-
- GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector
176-
- GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush
203+
Flow:
204+
GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
177205
"""
178206

179207
def __init__(self):
@@ -209,7 +237,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
209237

210238
for backend_name, backend_config in config.items():
211239
backend = get_logger_backend_class(backend_name)(backend_config)
212-
await backend.init(role=BackendRole.GLOBAL)
240+
await backend.init(role=BackendRole.GLOBAL, name="global_reduce")
213241

214242
# Extract metadata from primary logger to be shared with secondary loggers
215243
# and store it
@@ -237,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None:
237265
await asyncio.gather(*tasks, return_exceptions=True)
238266

239267
@endpoint
240-
async def register_fetcher(
241-
self, fetcher: LocalFetcherActor, name: str | ProcMesh
242-
) -> None:
243-
"""Registers a fetcher with the global actor. Each key represents a process mesh.
244-
If there are 2 processes, each with 2 replicas with N gpus, we would
245-
have 4 keys, i.e. 2 proces meshes, each with 2 replicas."""
246-
self.fetchers[name] = fetcher # pyre-ignore
268+
async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None:
269+
"""Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh.
270+
271+
Args:
272+
fetcher: The LocalFetcherActor instance for a ProcMesh
273+
proc_id: Unique identifier for the ProcMesh
274+
"""
275+
self.fetchers[proc_id] = fetcher
247276

248277
# Self-init for respawned actors
249278
if self.config:
250-
logger.debug(f"Initializing new LocalFetcherActor {name}")
279+
logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}")
251280
await fetcher.init_backends.call(
252281
self.metadata_per_primary_backend, self.config
253282
)
254283

255284
@endpoint
256-
async def deregister_fetcher(self, name: str | ProcMesh) -> None:
257-
if name not in self.fetchers:
285+
async def deregister_fetcher(self, proc_id: str) -> None:
286+
if proc_id not in self.fetchers:
258287
logger.warning(
259-
f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
288+
f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister."
260289
f"Available fetchers: {self.fetchers.keys()}"
261290
)
262291
return
263-
del self.fetchers[name]
292+
del self.fetchers[proc_id]
264293

265294
@endpoint
266295
async def flush(self, global_step: int) -> None:
@@ -333,9 +362,9 @@ async def flush(self, global_step: int) -> None:
333362
await logger_backend.log(reduced_metrics, global_step)
334363

335364
@endpoint
336-
def has_fetcher(self, name: str | ProcMesh) -> bool:
337-
"""Check if a fetcher is registered with the given name."""
338-
return name in self.fetchers
365+
def has_fetcher(self, proc_id: str) -> bool:
366+
"""Check if a fetcher is registered with the given proc_id."""
367+
return proc_id in self.fetchers
339368

340369
@endpoint
341370
def get_fetcher_count(self) -> int:

0 commit comments

Comments (0)