66
77import asyncio
88import logging
9+ import uuid
910from typing import Any , Union
1011
11- from monarch .actor import Actor , endpoint , get_or_spawn_controller , ProcMesh , this_proc
12+ from monarch .actor import (
13+ Actor ,
14+ context ,
15+ endpoint ,
16+ get_or_spawn_controller ,
17+ ProcMesh ,
18+ this_proc ,
19+ )
1220
1321from forge .env import FORGE_DISABLE_METRICS
1422from forge .observability .metrics import (
2735
2836async def get_or_create_metric_logger (
2937 proc_mesh : ProcMesh | None = None ,
38+ process_name : str | None = None ,
3039) -> "GlobalLoggingActor" :
31- """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
32- if not already initialized, registers it with the GlobalLoggingActor and returns the
33- GlobalLoggingActor instance.
40+ """Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized),
41+ registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor.
3442
35- There are primarily two ways to use this function :
36- 1. In the main process, call `get_or_create_metric_logger()` to get the global logger.
37- 2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the
38- local fetcher with the global logger.
43+ Usage :
44+ 1. Main process: call `get_or_create_metric_logger()` to get the global logger
45+ 2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name )` to register the
46+ map(proc_mesh, local fetcher) with the global logger, so it knows to broadcast to all ranks .
3947
4048 Args:
41- proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
42- uses `monarch.actor.this_proc()` .
49+ proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`.
50+ process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None .
4351
4452 Returns:
4553 GlobalLoggingActor: The global logging controller.
4654
4755 Raises:
48- ValueError: If the logging state is inconsistent, i.e. the fetcher is already
49- registered, but only in the process or the global logger.
56+ ValueError: If the logging state is inconsistent.
5057
5158 Example:
5259 from forge.observability.metric_actors import get_or_create_metric_logger
5360 from forge.observability.metrics import record_metric
5461
5562 # Main process setup
56- mlogger = await get_or_create_metric_logger()
63+ mlogger = await get_or_create_metric_logger(process_name="Controller" )
5764
5865 # Initialize logging backends
59- await mlogger.init_backends({
66+ await mlogger.init_backends.call_one ({
6067 "console": {"reduce_across_ranks": True},
6168 "wandb": {"project": "my_project", "reduce_across_ranks": False}
6269 })
@@ -66,12 +73,12 @@ async def get_or_create_metric_logger(
6673
6774 # Training loop
6875 for step in range(max_steps):
69- record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
76+ record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
7077 # ... training code with record_metric() calls ...
71- await mlogger.flush(step) # Log metrics for this step
78+ await mlogger.flush.call_one (step) # Log metrics for this step
7279
7380 # Shutdown
74- await mlogger.shutdown()
81+ await mlogger.shutdown.call_one ()
7582 """
7683 # Get or create the singleton global logger
7784 global _global_logger
@@ -85,9 +92,15 @@ async def get_or_create_metric_logger(
8592 # Determine process context
8693 proc = proc_mesh if proc_mesh is not None else this_proc ()
8794
95+ # Auto-detect process_name from proc mesh if not provided
96+ if process_name is None :
97+ ctx = context ()
98+ process_name = ctx .actor_instance .actor_id .actor_name
99+
88100 # Check current state for consistency
89101 proc_has_local_fetcher = hasattr (proc , "_local_fetcher" )
90- global_logger_has_local_fetcher = await global_logger .has_fetcher .call_one (proc )
102+ proc_id = proc ._uid if proc_has_local_fetcher else None
103+ global_logger_has_local_fetcher = await global_logger .has_fetcher .call_one (proc_id )
91104
92105 # Consistency check: both should be in sync
93106 if proc_has_local_fetcher != global_logger_has_local_fetcher :
@@ -102,24 +115,32 @@ async def get_or_create_metric_logger(
102115 # Setup local_fetcher_actor if needed (unless disabled by environment flag)
103116 if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS .get_value ():
104117 local_fetcher_actor = proc .spawn (
105- "local_fetcher_actor" , LocalFetcherActor , global_logger
118+ "local_fetcher_actor" , LocalFetcherActor , global_logger , process_name
106119 )
107- await global_logger .register_fetcher .call_one (local_fetcher_actor , proc )
120+ # Generate a unique ID to map procmesh to fetcher
121+ proc ._uid = str (uuid .uuid4 ())
108122 proc ._local_fetcher = local_fetcher_actor # pyre-ignore
109123
124+ await global_logger .register_fetcher .call_one (local_fetcher_actor , proc ._uid )
125+
110126 return global_logger
111127
112128
113129class LocalFetcherActor (Actor ):
114- """Thin per-process actor used to trigger MetricCollector singleton
115- operations without direct access. It is what GlobalLoggingActor
116- uses to broadcast inits/flushes across ranks.
130+ """Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh
131+ and accesses each rank's local MetricCollector.
117132
118- GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
133+ Flow:
134+ GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
119135 """
120136
121- def __init__ (self , global_logger : Union ["GlobalLoggingActor" , None ] = None ) -> None :
137+ def __init__ (
138+ self ,
139+ global_logger : Union ["GlobalLoggingActor" , None ] = None ,
140+ process_name : str | None = None ,
141+ ) -> None :
122142 self .global_logger = global_logger
143+ self .process_name = process_name
123144 _is_initialized = False
124145
125146 @endpoint
@@ -146,10 +167,22 @@ async def init_backends(
146167 self ,
147168 metadata_per_primary_backend : dict [str , dict [str , Any ]],
148169 config : dict [str , Any ],
170+ global_step : int = 0 ,
149171 ) -> None :
150- """Init local (per-rank) logger backends and MetricCollector."""
172+ """Init per-rank logger backends and MetricCollector.
173+
174+ Args:
175+ metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
176+ config (dict[str, Any]): Backend configurations with logging modes and settings.
177+ global_step (int): Initial step for metrics.
178+ """
151179 collector = MetricCollector ()
152- await collector .init_backends (metadata_per_primary_backend , config )
180+ await collector .init_backends (
181+ metadata_per_primary_backend ,
182+ config ,
183+ global_step ,
184+ process_name = self .process_name ,
185+ )
153186
154187 @endpoint
155188 async def shutdown (self ) -> None :
@@ -158,22 +191,17 @@ async def shutdown(self) -> None:
158191
159192
160193class GlobalLoggingActor (Actor ):
161- """Coordinates metric logging across all ranks for every training step .
194+ """Coordinates metric logging across all ProcMeshes and their ranks .
162195
163196 Supports multiple logging backends (e.g., WandB, TensorBoard, etc.),
164- for per-rank and/or global reduction logging modes.
197+ with per-rank and/or global reduction logging modes.
165198
166199 If a backend config has flag `reduce_across_ranks=False`, an instance of the backend
167200 is initialized per-rank, otherwise it is done once globally.
168201
169- This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor
170- is automatically spawned per-rank in `forge.controller.provisioner.py` and registered
171- with this actor. The LocalFetcherActor is responsible for instantiating
172- the per-rank MetricCollector.
173202
174- In summary, the flow is:
175- - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector
176- - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush
203+ Flow:
204+ GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
177205 """
178206
179207 def __init__ (self ):
@@ -209,7 +237,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
209237
210238 for backend_name , backend_config in config .items ():
211239 backend = get_logger_backend_class (backend_name )(backend_config )
212- await backend .init (role = BackendRole .GLOBAL )
240+ await backend .init (role = BackendRole .GLOBAL , name = "global_reduce" )
213241
214242 # Extract metadata from primary logger to be shared with secondary loggers
215243 # and store it
@@ -237,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None:
237265 await asyncio .gather (* tasks , return_exceptions = True )
238266
239267 @endpoint
240- async def register_fetcher (
241- self , fetcher : LocalFetcherActor , name : str | ProcMesh
242- ) -> None :
243- """Registers a fetcher with the global actor. Each key represents a process mesh.
244- If there are 2 processes, each with 2 replicas with N gpus, we would
245- have 4 keys, i.e. 2 proces meshes, each with 2 replicas."""
246- self .fetchers [name ] = fetcher # pyre-ignore
268+ async def register_fetcher (self , fetcher : LocalFetcherActor , proc_id : str ) -> None :
269+ """Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh.
270+
271+ Args:
272+ fetcher: The LocalFetcherActor instance for a ProcMesh
273+ proc_id: Unique identifier for the ProcMesh
274+ """
275+ self .fetchers [proc_id ] = fetcher
247276
248277 # Self-init for respawned actors
249278 if self .config :
250- logger .debug (f"Initializing new LocalFetcherActor { name } " )
279+ logger .debug (f"Initializing new LocalFetcherActor for proc_id= { proc_id } " )
251280 await fetcher .init_backends .call (
252281 self .metadata_per_primary_backend , self .config
253282 )
254283
255284 @endpoint
256- async def deregister_fetcher (self , name : str | ProcMesh ) -> None :
257- if name not in self .fetchers :
285+ async def deregister_fetcher (self , proc_id : str ) -> None :
286+ if proc_id not in self .fetchers :
258287 logger .warning (
259- f"Fetcher { name } not registered in GlobalLoggingActor. Cannot deregister."
288+ f"Fetcher { proc_id } not registered in GlobalLoggingActor. Cannot deregister."
260289 f"Available fetchers: { self .fetchers .keys ()} "
261290 )
262291 return
263- del self .fetchers [name ]
292+ del self .fetchers [proc_id ]
264293
265294 @endpoint
266295 async def flush (self , global_step : int ) -> None :
@@ -333,9 +362,9 @@ async def flush(self, global_step: int) -> None:
333362 await logger_backend .log (reduced_metrics , global_step )
334363
335364 @endpoint
336- def has_fetcher (self , name : str | ProcMesh ) -> bool :
337- """Check if a fetcher is registered with the given name ."""
338- return name in self .fetchers
365+ def has_fetcher (self , proc_id : str ) -> bool :
366+ """Check if a fetcher is registered with the given proc_id ."""
367+ return proc_id in self .fetchers
339368
340369 @endpoint
341370 def get_fetcher_count (self ) -> int :
0 commit comments