88import logging
99from typing import Any , Union
1010
11- from monarch .actor import Actor , endpoint , get_or_spawn_controller , ProcMesh , this_proc
11+ from monarch .actor import (
12+ Actor ,
13+ context ,
14+ endpoint ,
15+ get_or_spawn_controller ,
16+ ProcMesh ,
17+ this_proc ,
18+ )
1219
1320from forge .env import FORGE_DISABLE_METRICS
1421from forge .observability .metrics import (
2734
2835async def get_or_create_metric_logger (
2936 proc_mesh : ProcMesh | None = None ,
37+ process_name : str | None = None ,
3038) -> "GlobalLoggingActor" :
3139 """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
3240 if not already initialized, registers it with the GlobalLoggingActor and returns the
@@ -40,6 +48,9 @@ async def get_or_create_metric_logger(
4048 Args:
4149 proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
4250 uses `monarch.actor.this_proc()`.
51+ process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging.
52+ If None, it is auto-detected from the current actor's name (the mesh_name provided during
53+ actor initialization, or a generic mesh name if one was not provided).
4354
4455 Returns:
4556 GlobalLoggingActor: The global logging controller.
@@ -53,7 +64,7 @@ async def get_or_create_metric_logger(
5364 from forge.observability.metrics import record_metric
5465
5566 # Main process setup
56- mlogger = await get_or_create_metric_logger()
67+ mlogger = await get_or_create_metric_logger(process_name="Controller" )
5768
5869 # Initialize logging backends
5970 await mlogger.init_backends({
@@ -66,13 +77,14 @@ async def get_or_create_metric_logger(
6677
6778 # Training loop
6879 for step in range(max_steps):
69- record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
80+ record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
7081 # ... training code with record_metric() calls ...
7182 await mlogger.flush(step) # Log metrics for this step
7283
7384 # Shutdown
7485 await mlogger.shutdown()
7586 """
87+
7688 # Get or create the singleton global logger
7789 global _global_logger
7890 if _global_logger is None :
@@ -84,14 +96,19 @@ async def get_or_create_metric_logger(
8496 # Determine process context
8597 proc = proc_mesh if proc_mesh is not None else this_proc ()
8698
99+ # Auto-detect process_name from the current actor context (its actor name) if not provided
100+ if process_name is None :
101+ ctx = context ()
102+ process_name = ctx .actor_instance .actor_id .actor_name
103+
87104 # Check current state for consistency
88105 proc_has_local_fetcher = hasattr (proc , "_local_fetcher" )
89106 global_logger_has_local_fetcher = await global_logger .has_fetcher .call_one (proc )
90107
91108 # Consistency check: both should be in sync
92109 if proc_has_local_fetcher != global_logger_has_local_fetcher :
93110 raise ValueError (
94- f"Inconsistent logging state for proc { proc } : "
111+ f"Inconsistent logging state for { proc = } with { process_name = } : "
95112 f"proc has _local_fetcher={ proc_has_local_fetcher } , "
96113 f"but global_logger has registration={ global_logger_has_local_fetcher } . "
97114 f"This indicates a bug in logging setup/teardown. "
@@ -101,7 +118,7 @@ async def get_or_create_metric_logger(
101118 # Setup local_fetcher_actor if needed (unless disabled by environment flag)
102119 if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS .get_value ():
103120 local_fetcher_actor = proc .spawn (
104- "local_fetcher_actor" , LocalFetcherActor , global_logger
121+ "local_fetcher_actor" , LocalFetcherActor , global_logger , process_name
105122 )
106123 await global_logger .register_fetcher .call_one (local_fetcher_actor , proc )
107124 proc ._local_fetcher = local_fetcher_actor # pyre-ignore
@@ -117,8 +134,13 @@ class LocalFetcherActor(Actor):
117134 GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
118135 """
119136
120- def __init__ (self , global_logger : Union ["GlobalLoggingActor" , None ] = None ) -> None :
137+ def __init__ (
138+ self ,
139+ global_logger : Union ["GlobalLoggingActor" , None ] = None ,
140+ process_name : str | None = None ,
141+ ) -> None :
121142 self .global_logger = global_logger
143+ self .process_name = process_name # Passed to MetricCollector for logging
122144 _is_initialized = False
123145
124146 @endpoint
@@ -145,10 +167,22 @@ async def init_backends(
145167 self ,
146168 metadata_per_primary_backend : dict [str , dict [str , Any ]],
147169 config : dict [str , Any ],
170+ global_step : int = 0 ,
148171 ) -> None :
149- """Init local (per-rank) logger backends and MetricCollector."""
172+ """Init local (per-rank) logger backends and MetricCollector.
173+
174+ Args:
175+ metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
176+ config (dict[str, Any]): Backend configurations with logging modes and settings.
177+ global_step (int): Initial step for metrics.
178+ """
150179 collector = MetricCollector ()
151- await collector .init_backends (metadata_per_primary_backend , config )
180+ await collector .init_backends (
181+ metadata_per_primary_backend ,
182+ config ,
183+ global_step ,
184+ process_name = self .process_name ,
185+ )
152186
153187 @endpoint
154188 async def shutdown (self ) -> None :
0 commit comments