Skip to content

Commit fffcb88

Browse files
committed
Revert "Metric Logging updates 4/N - better actor name (#351)"
This reverts commit 1f45470.
1 parent 25d8a7a commit fffcb88

File tree

13 files changed

+103
-376
lines changed

13 files changed

+103
-376
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
305305
provisioner = await init_provisioner()
306306

307307
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
308-
mlogger = await get_or_create_metric_logger(process_name="Controller")
308+
mlogger = await get_or_create_metric_logger()
309309
await mlogger.init_backends.call_one(metric_logging_cfg)
310310

311311
# ---- Setup services ---- #

src/forge/controller/provisioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -305,7 +305,7 @@ def bootstrap(env: dict[str, str]):
305305
if not FORGE_DISABLE_METRICS.get_value():
306306
from forge.observability.metric_actors import get_or_create_metric_logger
307307

308-
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
308+
_ = await get_or_create_metric_logger(procs)
309309
return procs
310310

311311
async def host_mesh_from_proc(self, proc_mesh: ProcMesh):

src/forge/observability/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from .metrics import (
1313
BackendRole,
1414
ConsoleBackend,
15+
get_actor_name_with_rank,
16+
get_logger_backend_class,
1517
LoggerBackend,
1618
MaxAccumulator,
1719
MeanAccumulator,
@@ -27,12 +29,12 @@
2729
WandbBackend,
2830
)
2931
from .perf_tracker import trace, Tracer
30-
from .utils import get_proc_name_with_rank
3132

3233
__all__ = [
3334
# Main API functions
3435
"record_metric",
3536
"reduce_metrics_states",
37+
"get_actor_name_with_rank",
3638
"get_logger_backend_class",
3739
"get_or_create_metric_logger",
3840
# Performance tracking
@@ -43,8 +45,6 @@
4345
"BackendRole",
4446
# Enums
4547
"Reduce",
46-
# Utility functions
47-
"get_proc_name_with_rank",
4848
# Actor classes
4949
"GlobalLoggingActor",
5050
"LocalFetcherActor",

src/forge/observability/metric_actors.py

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,7 @@
88
import logging
99
from typing import Any, Union
1010

11-
from monarch.actor import (
12-
Actor,
13-
context,
14-
endpoint,
15-
get_or_spawn_controller,
16-
ProcMesh,
17-
this_proc,
18-
)
11+
from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
1912

2013
from forge.env import FORGE_DISABLE_METRICS
2114
from forge.observability.metrics import (
@@ -34,7 +27,6 @@
3427

3528
async def get_or_create_metric_logger(
3629
proc_mesh: ProcMesh | None = None,
37-
process_name: str | None = None,
3830
) -> "GlobalLoggingActor":
3931
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
4032
if not already initialized, registers it with the GlobalLoggingActor and returns the
@@ -48,9 +40,6 @@ async def get_or_create_metric_logger(
4840
Args:
4941
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
5042
uses `monarch.actor.this_proc()`.
51-
process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging.
52-
If None, will be auto-detected from the mesh_name provided during actor initialization or
53-
a generic mesh name if one was not provided.
5443
5544
Returns:
5645
GlobalLoggingActor: The global logging controller.
@@ -64,7 +53,7 @@ async def get_or_create_metric_logger(
6453
from forge.observability.metrics import record_metric
6554
6655
# Main process setup
67-
mlogger = await get_or_create_metric_logger(process_name="Controller")
56+
mlogger = await get_or_create_metric_logger()
6857
6958
# Initialize logging backends
7059
await mlogger.init_backends({
@@ -77,14 +66,13 @@ async def get_or_create_metric_logger(
7766
7867
# Training loop
7968
for step in range(max_steps):
80-
record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
69+
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
8170
# ... training code with record_metric() calls ...
8271
await mlogger.flush(step) # Log metrics for this step
8372
8473
# Shutdown
8574
await mlogger.shutdown()
8675
"""
87-
8876
# Get or create the singleton global logger
8977
global _global_logger
9078
if _global_logger is None:
@@ -96,19 +84,14 @@ async def get_or_create_metric_logger(
9684
# Determine process context
9785
proc = proc_mesh if proc_mesh is not None else this_proc()
9886

99-
# Auto-detect process_name from proc mesh if not provided
100-
if process_name is None:
101-
ctx = context()
102-
process_name = ctx.actor_instance.actor_id.actor_name
103-
10487
# Check current state for consistency
10588
proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
10689
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
10790

10891
# Consistency check: both should be in sync
10992
if proc_has_local_fetcher != global_logger_has_local_fetcher:
11093
raise ValueError(
111-
f"Inconsistent logging state for {proc=} with {process_name=}: "
94+
f"Inconsistent logging state for proc {proc}: "
11295
f"proc has _local_fetcher={proc_has_local_fetcher}, "
11396
f"but global_logger has registration={global_logger_has_local_fetcher}. "
11497
f"This indicates a bug in logging setup/teardown. "
@@ -118,7 +101,7 @@ async def get_or_create_metric_logger(
118101
# Setup local_fetcher_actor if needed (unless disabled by environment flag)
119102
if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
120103
local_fetcher_actor = proc.spawn(
121-
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name
104+
"local_fetcher_actor", LocalFetcherActor, global_logger
122105
)
123106
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
124107
proc._local_fetcher = local_fetcher_actor # pyre-ignore
@@ -134,13 +117,8 @@ class LocalFetcherActor(Actor):
134117
GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
135118
"""
136119

137-
def __init__(
138-
self,
139-
global_logger: Union["GlobalLoggingActor", None] = None,
140-
process_name: str | None = None,
141-
) -> None:
120+
def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
142121
self.global_logger = global_logger
143-
self.process_name = process_name # Passed to MetricCollector for logging
144122
_is_initialized = False
145123

146124
@endpoint
@@ -167,22 +145,10 @@ async def init_backends(
167145
self,
168146
metadata_per_primary_backend: dict[str, dict[str, Any]],
169147
config: dict[str, Any],
170-
global_step: int = 0,
171148
) -> None:
172-
"""Init local (per-rank) logger backends and MetricCollector.
173-
174-
Args:
175-
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
176-
config (dict[str, Any]): Backend configurations with logging modes and settings.
177-
global_step (int): Initial step for metrics.
178-
"""
149+
"""Init local (per-rank) logger backends and MetricCollector."""
179150
collector = MetricCollector()
180-
await collector.init_backends(
181-
metadata_per_primary_backend,
182-
config,
183-
global_step,
184-
process_name=self.process_name,
185-
)
151+
await collector.init_backends(metadata_per_primary_backend, config)
186152

187153
@endpoint
188154
async def shutdown(self) -> None:

0 commit comments

Comments
 (0)