2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
     provisioner = await init_provisioner()
 
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
-    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)
 
     # ---- Setup services ---- #
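
Illustrative sketch (not part of the diff): after this change, the controller entry point creates the logger without a process name. A minimal version of the updated call site, assuming the same `metric_logging` config shape as `apps/grpo/main.py` (`setup_metric_logging` is a hypothetical helper):

    from forge.observability.metric_actors import get_or_create_metric_logger

    async def setup_metric_logging(cfg):
        # No process_name argument anymore; the logger is created for the
        # current (controller) process context.
        mlogger = await get_or_create_metric_logger()
        # Fall back to console logging when no metric_logging section is given.
        metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
        await mlogger.init_backends.call_one(metric_logging_cfg)
        return mlogger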
2 changes: 1 addition & 1 deletion src/forge/controller/provisioner.py
@@ -305,7 +305,7 @@ def bootstrap(env: dict[str, str]):
         if not FORGE_DISABLE_METRICS.get_value():
             from forge.observability.metric_actors import get_or_create_metric_logger
 
-            _ = await get_or_create_metric_logger(procs, process_name=mesh_name)
+            _ = await get_or_create_metric_logger(procs)
         return procs
 
     async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
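
Illustrative sketch (not part of the diff): on the provisioner side, a freshly spawned mesh is registered the same way, just without the mesh name. A hypothetical helper mirroring that logic, where `procs` is the new ProcMesh:

    from forge.env import FORGE_DISABLE_METRICS
    from forge.observability.metric_actors import get_or_create_metric_logger

    async def attach_metrics(procs):
        # Spawns and registers a LocalFetcherActor on the mesh unless
        # metrics are disabled via the environment flag.
        if not FORGE_DISABLE_METRICS.get_value():
            _ = await get_or_create_metric_logger(procs)
        return procs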
6 changes: 3 additions & 3 deletions src/forge/observability/__init__.py
@@ -12,6 +12,8 @@
 from .metrics import (
     BackendRole,
     ConsoleBackend,
+    get_actor_name_with_rank,
+    get_logger_backend_class,
     LoggerBackend,
     MaxAccumulator,
     MeanAccumulator,
@@ -27,12 +29,12 @@
     WandbBackend,
 )
 from .perf_tracker import trace, Tracer
-from .utils import get_proc_name_with_rank
 
 __all__ = [
     # Main API functions
     "record_metric",
     "reduce_metrics_states",
+    "get_actor_name_with_rank",
     "get_logger_backend_class",
     "get_or_create_metric_logger",
     # Performance tracking
@@ -43,8 +45,6 @@
     "BackendRole",
     # Enums
     "Reduce",
-    # Utility functions
-    "get_proc_name_with_rank",
     # Actor classes
     "GlobalLoggingActor",
     "LocalFetcherActor",
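
Illustrative sketch (not part of the diff): with the exports restored, both helpers resolve from the package root again; the loss value and step are placeholders:

    # All four names are listed in forge.observability.__all__ after this change.
    from forge.observability import (
        get_actor_name_with_rank,
        get_logger_backend_class,
        record_metric,
        Reduce,
    )

    record_metric("loss", 1.2, 0, reduction_type=Reduce.MEAN)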
50 changes: 8 additions & 42 deletions src/forge/observability/metric_actors.py
@@ -8,14 +8,7 @@
 import logging
 from typing import Any, Union
 
-from monarch.actor import (
-    Actor,
-    context,
-    endpoint,
-    get_or_spawn_controller,
-    ProcMesh,
-    this_proc,
-)
+from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
 
 from forge.env import FORGE_DISABLE_METRICS
 from forge.observability.metrics import (
@@ -34,7 +27,6 @@
 
 async def get_or_create_metric_logger(
     proc_mesh: ProcMesh | None = None,
-    process_name: str | None = None,
 ) -> "GlobalLoggingActor":
     """Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
     if not already initialized, registers it with the GlobalLoggingActor and returns the
@@ -48,9 +40,6 @@ async def get_or_create_metric_logger(
     Args:
         proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
             uses `monarch.actor.this_proc()`.
-        process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging.
-            If None, will be auto-detected from the mesh_name provided during actor initialization or
-            a generic mesh name if one was not provided.
 
     Returns:
         GlobalLoggingActor: The global logging controller.
@@ -64,7 +53,7 @@ async def get_or_create_metric_logger(
         from forge.observability.metrics import record_metric
 
         # Main process setup
-        mlogger = await get_or_create_metric_logger(process_name="Controller")
+        mlogger = await get_or_create_metric_logger()
 
         # Initialize logging backends
         await mlogger.init_backends({
@@ -77,14 +66,13 @@
 
         # Training loop
         for step in range(max_steps):
-            record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
+            record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
             # ... training code with record_metric() calls ...
             await mlogger.flush(step)  # Log metrics for this step
 
         # Shutdown
         await mlogger.shutdown()
     """
-
     # Get or create the singleton global logger
     global _global_logger
     if _global_logger is None:
@@ -96,19 +84,14 @@
     # Determine process context
     proc = proc_mesh if proc_mesh is not None else this_proc()
 
-    # Auto-detect process_name from proc mesh if not provided
-    if process_name is None:
-        ctx = context()
-        process_name = ctx.actor_instance.actor_id.actor_name
-
     # Check current state for consistency
     proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
     global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
 
     # Consistency check: both should be in sync
     if proc_has_local_fetcher != global_logger_has_local_fetcher:
         raise ValueError(
-            f"Inconsistent logging state for {proc=} with {process_name=}: "
+            f"Inconsistent logging state for proc {proc}: "
             f"proc has _local_fetcher={proc_has_local_fetcher}, "
             f"but global_logger has registration={global_logger_has_local_fetcher}. "
             f"This indicates a bug in logging setup/teardown. "
@@ -118,7 +101,7 @@
     # Setup local_fetcher_actor if needed (unless disabled by environment flag)
     if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
         local_fetcher_actor = proc.spawn(
-            "local_fetcher_actor", LocalFetcherActor, global_logger, process_name
+            "local_fetcher_actor", LocalFetcherActor, global_logger
         )
         await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
         proc._local_fetcher = local_fetcher_actor  # pyre-ignore
@@ -134,13 +117,8 @@ class LocalFetcherActor(Actor):
     GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
     """
 
-    def __init__(
-        self,
-        global_logger: Union["GlobalLoggingActor", None] = None,
-        process_name: str | None = None,
-    ) -> None:
+    def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
         self.global_logger = global_logger
-        self.process_name = process_name  # Passed to MetricCollector for logging
         _is_initialized = False
 
     @endpoint
@@ -167,22 +145,10 @@ async def init_backends(
         self,
         metadata_per_primary_backend: dict[str, dict[str, Any]],
         config: dict[str, Any],
-        global_step: int = 0,
     ) -> None:
-        """Init local (per-rank) logger backends and MetricCollector.
-
-        Args:
-            metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
-            config (dict[str, Any]): Backend configurations with logging modes and settings.
-            global_step (int): Initial step for metrics.
-        """
+        """Init local (per-rank) logger backends and MetricCollector."""
         collector = MetricCollector()
-        await collector.init_backends(
-            metadata_per_primary_backend,
-            config,
-            global_step,
-            process_name=self.process_name,
-        )
+        await collector.init_backends(metadata_per_primary_backend, config)
 
     @endpoint
     async def shutdown(self) -> None:
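
Illustrative sketch (not part of the diff): with `global_step` removed from `init_backends`, the step now travels with each metric call and flush, as in the updated docstring example; `max_steps` and the loss value are placeholders:

    from forge.observability import record_metric, Reduce

    async def train_loop(mlogger, max_steps: int):
        for step in range(max_steps):
            # The step is passed per call now that backends are no longer
            # initialized with an initial global_step.
            record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
            await mlogger.flush(step)  # log metrics accumulated for this step
        await mlogger.shutdown()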