Skip to content

Commit 0496edb

Browse files
felipemello1, Felipe Mello
authored and committed
Metric Logging updates 4/N - better actor name (#351)
Co-authored-by: Felipe Mello <[email protected]>
1 parent a3bf0f7 commit 0496edb

File tree

13 files changed

+376
-103
lines changed

13 files changed

+376
-103
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
305305
provisioner = await init_provisioner()
306306

307307
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
308-
mlogger = await get_or_create_metric_logger()
308+
mlogger = await get_or_create_metric_logger(process_name="Controller")
309309
await mlogger.init_backends.call_one(metric_logging_cfg)
310310

311311
# ---- Setup services ---- #

src/forge/controller/provisioner.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -313,7 +313,7 @@ def bootstrap(env: dict[str, str]):
313313
if not FORGE_DISABLE_METRICS.get_value():
314314
from forge.observability.metric_actors import get_or_create_metric_logger
315315

316-
_ = await get_or_create_metric_logger(procs)
316+
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
317317
return procs
318318

319319
async def host_mesh_from_proc(self, proc_mesh: ProcMesh):

src/forge/observability/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,8 +12,6 @@
1212
from .metrics import (
1313
BackendRole,
1414
ConsoleBackend,
15-
get_actor_name_with_rank,
16-
get_logger_backend_class,
1715
LoggerBackend,
1816
MaxAccumulator,
1917
MeanAccumulator,
@@ -29,12 +27,12 @@
2927
WandbBackend,
3028
)
3129
from .perf_tracker import trace, Tracer
30+
from .utils import get_proc_name_with_rank
3231

3332
__all__ = [
3433
# Main API functions
3534
"record_metric",
3635
"reduce_metrics_states",
37-
"get_actor_name_with_rank",
3836
"get_logger_backend_class",
3937
"get_or_create_metric_logger",
4038
# Performance tracking
@@ -45,6 +43,8 @@
4543
"BackendRole",
4644
# Enums
4745
"Reduce",
46+
# Utility functions
47+
"get_proc_name_with_rank",
4848
# Actor classes
4949
"GlobalLoggingActor",
5050
"LocalFetcherActor",

src/forge/observability/metric_actors.py

Lines changed: 42 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,14 @@
88
import logging
99
from typing import Any, Union
1010

11-
from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
11+
from monarch.actor import (
12+
Actor,
13+
context,
14+
endpoint,
15+
get_or_spawn_controller,
16+
ProcMesh,
17+
this_proc,
18+
)
1219

1320
from forge.env import FORGE_DISABLE_METRICS
1421
from forge.observability.metrics import (
@@ -27,6 +34,7 @@
2734

2835
async def get_or_create_metric_logger(
2936
proc_mesh: ProcMesh | None = None,
37+
process_name: str | None = None,
3038
) -> "GlobalLoggingActor":
3139
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
3240
if not already initialized, registers it with the GlobalLoggingActor and returns the
@@ -40,6 +48,9 @@ async def get_or_create_metric_logger(
4048
Args:
4149
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
4250
uses `monarch.actor.this_proc()`.
51+
process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging.
52+
If None, will be auto-detected from the mesh_name provided during actor initialization or
53+
a generic mesh name if one was not provided.
4354
4455
Returns:
4556
GlobalLoggingActor: The global logging controller.
@@ -53,7 +64,7 @@ async def get_or_create_metric_logger(
5364
from forge.observability.metrics import record_metric
5465
5566
# Main process setup
56-
mlogger = await get_or_create_metric_logger()
67+
mlogger = await get_or_create_metric_logger(process_name="Controller")
5768
5869
# Initialize logging backends
5970
await mlogger.init_backends({
@@ -66,13 +77,14 @@ async def get_or_create_metric_logger(
6677
6778
# Training loop
6879
for step in range(max_steps):
69-
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
80+
record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
7081
# ... training code with record_metric() calls ...
7182
await mlogger.flush(step) # Log metrics for this step
7283
7384
# Shutdown
7485
await mlogger.shutdown()
7586
"""
87+
7688
# Get or create the singleton global logger
7789
global _global_logger
7890
if _global_logger is None:
@@ -84,14 +96,19 @@ async def get_or_create_metric_logger(
8496
# Determine process context
8597
proc = proc_mesh if proc_mesh is not None else this_proc()
8698

99+
# Auto-detect process_name from proc mesh if not provided
100+
if process_name is None:
101+
ctx = context()
102+
process_name = ctx.actor_instance.actor_id.actor_name
103+
87104
# Check current state for consistency
88105
proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
89106
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
90107

91108
# Consistency check: both should be in sync
92109
if proc_has_local_fetcher != global_logger_has_local_fetcher:
93110
raise ValueError(
94-
f"Inconsistent logging state for proc {proc}: "
111+
f"Inconsistent logging state for {proc=} with {process_name=}: "
95112
f"proc has _local_fetcher={proc_has_local_fetcher}, "
96113
f"but global_logger has registration={global_logger_has_local_fetcher}. "
97114
f"This indicates a bug in logging setup/teardown. "
@@ -101,7 +118,7 @@ async def get_or_create_metric_logger(
101118
# Setup local_fetcher_actor if needed (unless disabled by environment flag)
102119
if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
103120
local_fetcher_actor = proc.spawn(
104-
"local_fetcher_actor", LocalFetcherActor, global_logger
121+
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name
105122
)
106123
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
107124
proc._local_fetcher = local_fetcher_actor # pyre-ignore
@@ -117,8 +134,13 @@ class LocalFetcherActor(Actor):
117134
GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
118135
"""
119136

120-
def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
137+
def __init__(
138+
self,
139+
global_logger: Union["GlobalLoggingActor", None] = None,
140+
process_name: str | None = None,
141+
) -> None:
121142
self.global_logger = global_logger
143+
self.process_name = process_name # Passed to MetricCollector for logging
122144
_is_initialized = False
123145

124146
@endpoint
@@ -145,10 +167,22 @@ async def init_backends(
145167
self,
146168
metadata_per_primary_backend: dict[str, dict[str, Any]],
147169
config: dict[str, Any],
170+
global_step: int = 0,
148171
) -> None:
149-
"""Init local (per-rank) logger backends and MetricCollector."""
172+
"""Init local (per-rank) logger backends and MetricCollector.
173+
174+
Args:
175+
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
176+
config (dict[str, Any]): Backend configurations with logging modes and settings.
177+
global_step (int): Initial step for metrics.
178+
"""
150179
collector = MetricCollector()
151-
await collector.init_backends(metadata_per_primary_backend, config)
180+
await collector.init_backends(
181+
metadata_per_primary_backend,
182+
config,
183+
global_step,
184+
process_name=self.process_name,
185+
)
152186

153187
@endpoint
154188
async def shutdown(self) -> None:

0 commit comments

Comments
 (0)