From 77488cfe630be4390593ee16d6ca0d24dc67f93f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 08:38:55 -0700 Subject: [PATCH 01/18] commit --- src/forge/controller/provisioner.py | 3 +- src/forge/env_constants.py | 1 + src/forge/observability/metric_actors.py | 9 +- src/forge/observability/metrics.py | 37 ++- tests/unit_tests/observability/conftest.py | 95 +++++++ .../unit_tests/observability/test_metrics.py | 231 ++++++++++++++++++ 6 files changed, 370 insertions(+), 6 deletions(-) create mode 100644 tests/unit_tests/observability/conftest.py create mode 100644 tests/unit_tests/observability/test_metrics.py diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index c823afb29..429c5760f 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -262,7 +262,8 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh - # Spawn local logging actor on each process and register with global logger + # Spawn local fetcher actor on each process and register with global logger + # Can be disabled by FORGE_DISABLE_METRICS env var _ = await get_or_create_metric_logger(procs) return procs diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index 3adcdfc41..a4e024d83 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -14,4 +14,5 @@ METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" # Makes forge.observability.metrics.record_metric a no-op +# and disables spawning LocalFetcherActor in get_or_create_metric_logger FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS" diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index d67a66a83..edd1f24d8 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -6,10 +6,12 @@ import asyncio import logging +import os from typing import Any, Dict, Optional from monarch.actor import Actor, endpoint, 
get_or_spawn_controller, ProcMesh, this_proc +from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( get_logger_backend_class, LoggerBackend, @@ -95,8 +97,11 @@ async def get_or_create_metric_logger( f"Both should be True (already setup) or both False (needs setup)." ) - # Setup local_fetcher_actor if needed - if not proc_has_local_fetcher: + # Setup local_fetcher_actor if needed (unless disabled by environment flag) + if ( + not proc_has_local_fetcher + and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true" + ): local_fetcher_actor = proc.spawn( "local_fetcher_actor", LocalFetcherActor, global_logger ) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 990a301e0..4d527ec1b 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -437,7 +437,21 @@ async def init_backends( def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: if not self._is_initialized: - raise ValueError("Collector not initialized—call init first") + from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg=( + "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." + " This happens when you try to use `record_metric` before calling `init_backends`." + " To disable this warning, please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger()`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "or set env variable `FORGE_DISABLE_METRICS=True`" + ), + ) + return if key not in self.accumulators: self.accumulators[key] = reduction.accumulator_class(reduction) @@ -458,8 +472,16 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - logger.debug( - f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." 
+ from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." + "\nPlease call in your main file:\n" + "`mlogger = await get_or_create_metric_logger()`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "before calling `flush`", ) return {} @@ -662,6 +684,15 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): raise ValueError( f"Shared ID required but not provided for {self.name} backend init" ) + + # Clear any stale service tokens that might be pointing to dead processes + # In multiprocessing environments, WandB service tokens can become stale and point + # to dead service processes. This causes wandb.init() to hang indefinitely trying + # to connect to non-existent services. Clearing forces fresh service connection. + from wandb.sdk.lib.service import service_token + + service_token.clear_service_in_env() + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) self.run = wandb.init( id=shared_id, diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py new file mode 100644 index 000000000..a803c252d --- /dev/null +++ b/tests/unit_tests/observability/conftest.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Shared fixtures and mocks for observability unit tests.""" + +from unittest.mock import MagicMock, patch + +import pytest +from forge.observability.metrics import LoggerBackend, MetricCollector + + +class MockBackend(LoggerBackend): + """Mock backend for testing metrics logging without external dependencies.""" + + def __init__(self, logger_backend_config=None): + super().__init__(logger_backend_config or {}) + self.logged_metrics = [] + self.init_called = False + self.finish_called = False + self.metadata = {} + + async def init(self, role="local", primary_logger_metadata=None): + self.init_called = True + self.role = role + self.primary_logger_metadata = primary_logger_metadata or {} + + async def log(self, metrics, step): + self.logged_metrics.append((metrics, step)) + + async def finish(self): + self.finish_called = True + + def get_metadata_for_secondary_ranks(self): + return self.metadata + + +@pytest.fixture(autouse=True) +def clear_metric_collector_singletons(): + """Clear MetricCollector singletons before each test to avoid state leakage.""" + MetricCollector._instances.clear() + yield + MetricCollector._instances.clear() + + +@pytest.fixture(autouse=True) +def clean_metrics_environment(): + """Override the global mock_metrics_globally fixture to allow real metrics testing.""" + import os + + from forge.env_constants import FORGE_DISABLE_METRICS + + # Set default state for tests (metrics enabled) + if FORGE_DISABLE_METRICS in os.environ: + del os.environ[FORGE_DISABLE_METRICS] + + yield + + +@pytest.fixture +def mock_rank(): + """Mock current_rank function with configurable rank.""" + with patch("forge.observability.metrics.current_rank") as mock: + rank_obj = MagicMock() + rank_obj.rank = 0 + mock.return_value = rank_obj + yield mock + + +@pytest.fixture +def mock_actor_context(): + """Mock Monarch actor context for testing actor name generation.""" + with patch("forge.observability.metrics.context") as mock_context, patch( + 
"forge.observability.metrics.current_rank" + ) as mock_rank: + + # Setup mock context + ctx = MagicMock() + actor_instance = MagicMock() + actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]" + ctx.actor_instance = actor_instance + mock_context.return_value = ctx + + # Setup mock rank + rank_obj = MagicMock() + rank_obj.rank = 0 + mock_rank.return_value = rank_obj + + yield { + "context": mock_context, + "rank": mock_rank, + "expected_name": "TestActor_0XQr_r0", + } diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py new file mode 100644 index 000000000..3e864bdf7 --- /dev/null +++ b/tests/unit_tests/observability/test_metrics.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for core metrics functionality focusing on critical fixes in Diff 1.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import ( + ConsoleBackend, + get_logger_backend_class, + MeanAccumulator, + MetricCollector, + record_metric, + Reduce, + WandbBackend, +) + + +class TestCriticalFixes: + """Test critical production fixes from Diff 1.""" + + def test_uninitialized_push_logs_warning(self, mock_rank, caplog): + """Test MetricCollector.push() logs warning when uninitialized.""" + collector = MetricCollector() + + # Should not raise error, just log warning and return + collector.push("test", 1.0, Reduce.MEAN) + assert any( + "Metric logging backends" in record.message for record in caplog.records + ) + + @pytest.mark.asyncio + async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog): + """Test MetricCollector.flush() logs warning when uninitialized.""" + collector = MetricCollector() + + # 
Should not raise error, just log warning and return empty dict + result = await collector.flush(step=1, return_state=True) + assert result == {} + assert any( + "Cannot flush collected metrics" in record.message + for record in caplog.records + ) + + @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "true"}) + @patch("forge.observability.metrics.MetricCollector") + def test_record_metric_disabled(self, mock_collector_class): + """Test record_metric is no-op when FORGE_DISABLE_METRICS=true.""" + record_metric("loss", 1.5, Reduce.MEAN) + mock_collector_class.assert_not_called() + + @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "false"}) + @patch("forge.observability.metrics.MetricCollector") + def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): + """Test record_metric works when FORGE_DISABLE_METRICS=false.""" + mock_collector = MagicMock() + mock_collector_class.return_value = mock_collector + + record_metric("loss", 1.5, Reduce.MEAN) + mock_collector_class.assert_called_once() + mock_collector.push.assert_called_once() + + @patch("forge.observability.metrics.get_actor_name_with_rank") + def test_wandb_backend_creation(self, mock_actor_name): + """Test WandbBackend creation and basic setup without WandB dependency.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + config = { + "project": "test_project", + "group": "test_group", + "reduce_across_ranks": True, + } + backend = WandbBackend(config) + + assert backend.project == "test_project" + assert backend.group == "test_group" + assert backend.reduce_across_ranks is True + assert backend.share_run_id is False # default + + # Test metadata method + metadata = backend.get_metadata_for_secondary_ranks() + assert metadata == {} # Should be empty when no run + + @patch("forge.observability.metrics.get_actor_name_with_rank") + @pytest.mark.asyncio + async def test_console_backend(self, mock_actor_name): + """Test ConsoleBackend basic operations.""" + 
mock_actor_name.return_value = "TestActor_abcd_r0" + + backend = ConsoleBackend({}) + + await backend.init(role="local") + + # Test log - should not raise + await backend.log({"test": 1.0}, step=1) + + await backend.finish() # Should not raise + + +class TestBasicAccumulators: + """Test basic accumulator functionality.""" + + def test_mean_accumulator(self): + """Test MeanAccumulator operations.""" + acc = MeanAccumulator(Reduce.MEAN) + + # Test initial state + assert acc.get_value() == 0.0 + state = acc.get_state() + assert state["sum"] == 0.0 + assert state["count"] == 0 + + # Test append and get_value + acc.append(10.0) + acc.append(20.0) + assert acc.get_value() == 15.0 + + # Test state + state = acc.get_state() + assert state["sum"] == 30.0 + assert state["count"] == 2 + assert state["reduction_type"] == "mean" + + # Test reset + acc.reset() + assert acc.get_value() == 0.0 + assert acc.get_state()["sum"] == 0.0 + assert acc.get_state()["count"] == 0 + + def test_reduce_enum_accumulator_mapping(self): + """Test that Reduce enum correctly maps to accumulator classes.""" + assert Reduce.MEAN.accumulator_class == MeanAccumulator + + +class TestBackendFactory: + """Test backend factory function.""" + + def test_backend_factory(self): + """Test get_logger_backend_class factory function.""" + assert get_logger_backend_class("console") == ConsoleBackend + assert get_logger_backend_class("wandb") == WandbBackend + + with pytest.raises(ValueError, match="Unknown logger backend type"): + get_logger_backend_class("invalid_backend") + + +class TestMetricCollector: + """Test MetricCollector singleton behavior.""" + + def test_singleton_per_rank(self, mock_rank): + """Test MetricCollector singleton behavior per rank.""" + mock_rank.return_value.rank = 0 + collector1 = MetricCollector() + collector2 = MetricCollector() + assert collector1 is collector2 + + # Different rank should get different instance + mock_rank.return_value.rank = 1 + collector3 = MetricCollector() + 
assert collector1 is not collector3 + + +class TestMetricActorDisabling: + """Test environment flag to disable metric actors.""" + + async def _test_fetcher_registration(self, env_var_value, should_register_fetchers): + """Check if FORGE_DISABLE_METRICS=[True, False, None] correctly disables fetcher registration. + + Args: + env_var_value: Value to set for FORGE_DISABLE_METRICS (None means unset) + should_register_fetchers: Whether fetchers should be registered (True) or not (False) + """ + import os + + import forge.observability.metric_actors + from forge.env_constants import FORGE_DISABLE_METRICS + from monarch.actor import this_host + + # set fresh env + # Note: Environment variable setup is handled by clean_metrics_environment fixture + forge.observability.metric_actors._global_logger = None + + if env_var_value is not None: + os.environ[FORGE_DISABLE_METRICS] = env_var_value + + procs = this_host().spawn_procs(per_host={"cpus": 1}) + + if hasattr(procs, "_local_fetcher"): + delattr(procs, "_local_fetcher") + + # Test functionality + global_logger = await get_or_create_metric_logger(proc_mesh=procs) + + # Get results to check + proc_has_fetcher = hasattr(procs, "_local_fetcher") + global_has_fetcher = await global_logger.has_fetcher.call_one(procs) + + # Assert based on expected behavior + if should_register_fetchers: + assert ( + proc_has_fetcher + ), f"Expected process to have _local_fetcher when {env_var_value=}" + assert ( + global_has_fetcher + ), f"Expected global logger to have fetcher registered when {env_var_value=}" + else: + assert ( + not proc_has_fetcher + ), f"Expected process to NOT have _local_fetcher when {env_var_value=}" + assert ( + not global_has_fetcher + ), f"Expected global logger to NOT have fetcher registered when {env_var_value=}" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "env_value,should_register", + [ + ("false", True), + ("true", False), + (None, True), + ], + ) + async def 
test_fetcher_registration_with_env_flag(self, env_value, should_register): + """Test fetcher registration behavior with different environment flag values.""" + await self._test_fetcher_registration(env_value, should_register) From feb4771a5e5b36bb44cc8c0bbbfde7c3311bcf57 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 09:26:47 -0700 Subject: [PATCH 02/18] commit --- src/forge/env_constants.py | 2 +- src/forge/observability/__init__.py | 10 +- src/forge/observability/metric_actors.py | 24 +- src/forge/observability/metrics.py | 121 +++++++--- src/forge/observability/perf_tracker.py | 4 +- .../unit_tests/observability/test_metrics.py | 208 +++++++++++++++++- .../observability/test_perf_tracker.py | 10 +- 7 files changed, 329 insertions(+), 50 deletions(-) diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index a4e024d83..6e0fc30e7 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -11,7 +11,7 @@ # Force all timing methods in forge.observability.perf_tracker.py to use # CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function. 
-METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" +METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU" # Makes forge.observability.metrics.record_metric a no-op # and disables spawning LocalFetcherActor in get_or_create_metric_logger diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 52262eed5..f37dacebd 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -10,15 +10,15 @@ LocalFetcherActor, ) from .metrics import ( + BackendRole, ConsoleBackend, - # Utility functions get_actor_name_with_rank, get_logger_backend_class, - # Backend classes LoggerBackend, + LoggingMode, MaxAccumulator, MeanAccumulator, - # Accumulator classes + Metric, MetricAccumulator, MetricCollector, MinAccumulator, @@ -41,8 +41,12 @@ # Performance tracking "Tracer", "trace", + # Data classes + "Metric", + "BackendRole", # Enums "Reduce", + "LoggingMode", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index edd1f24d8..85dd56150 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -106,7 +106,7 @@ async def get_or_create_metric_logger( "local_fetcher_actor", LocalFetcherActor, global_logger ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) - proc._local_fetcher = local_fetcher_actor + proc._local_fetcher = local_fetcher_actor # pyre-ignore return global_logger @@ -125,13 +125,13 @@ def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None @endpoint async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. This should only ever be called by the global logger. 
Args: - step (int): train step used by backends to align all metrics on the same x-axis + global_step (int): step used by backends to align all metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: @@ -139,7 +139,7 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ collector = MetricCollector() - result = await collector.flush(step, return_state=return_state) + result = await collector.flush(global_step, return_state=return_state) return result @endpoint @@ -260,13 +260,13 @@ async def deregister_fetcher(self, name: str | ProcMesh): del self.fetchers[name] @endpoint - async def flush(self, step: int): + async def flush(self, global_step: int): """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. Args: - step (int): Global step for logging. + global_step (int): step for logging. 
""" if not self.fetchers: return @@ -285,12 +285,14 @@ async def flush(self, step: int): for backend_config in config.values() ) - logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + logger.debug( + f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers" + ) # Broadcast flush to all fetchers results = await asyncio.gather( *[ - f.flush.call(step, return_state=requires_reduce) + f.flush.call(global_step, return_state=requires_reduce) for f in self.fetchers.values() ], return_exceptions=True, @@ -314,10 +316,10 @@ async def flush(self, step: int): ) if not all_local_states: - logger.warning(f"No states to reduce for step {step}") + logger.warning(f"No states to reduce for global_step {global_step}") return - # Reduce + # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) # Log to each global logger_backend @@ -325,7 +327,7 @@ async def flush(self, step: int): logger_backend_name, logger_backend, ) in self.global_logger_backends.items(): - await logger_backend.log(reduced_metrics, step) + await logger_backend.log(reduced_metrics, global_step) @endpoint def has_fetcher(self, name: str | ProcMesh) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4d527ec1b..a32e6d7a1 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -5,17 +5,56 @@ # LICENSE file in the root directory of this source tree. import logging - import os from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional +import pytz from monarch.actor import context, current_rank logger = logging.getLogger(__name__) +class BackendRole: + """Backend role constants for metric logging actors. + + Defines whether an actor operates as a local (per-rank) or global (controller) role + in the distributed metrics collection system. 
+ """ + + LOCAL: str = "local" + GLOBAL: str = "global" + + +class LoggingMode(Enum): + """Metric logging behavior for distributed training scenarios. + + Each mode serves different observability needs: + + GLOBAL_REDUCE = "global_reduce" + Best for: Metrics that are best visualized as a single value per step. + Behavior: All ranks accumulate → controller reduces → single log entry + Example use: 8 ranks training, want 1 loss value per step averaged across all + + PER_RANK_REDUCE = "per_rank_reduce" + Best for: Per-rank performance metrics, debugging individual rank behavior + Behavior: Each rank accumulates + logs its own reduced values + Example use: Monitor GPU utilization per rank, get 8 separate log entries per step + + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + Best for: Real-time streaming, time-series debugging + Behavior: Raw values logged immediately on record_metric() calls + Example use: See what every rank is doing in real time. + """ + + GLOBAL_REDUCE = "global_reduce" + PER_RANK_REDUCE = "per_rank_reduce" + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -35,6 +74,24 @@ def accumulator_class(self): return mapping[self] +@dataclass +class Metric: + """Container for metric data including key, value, reduction type, and timestamp. + + Timestamp is automatically set to the current UTC time if not provided. + """ + + key: str + value: Any + reduction: Reduce + timestamp: Optional[float] = None + + def __post_init__(self): + if self.timestamp is None: + # Always record in UTC timezone + self.timestamp = datetime.now(pytz.UTC).timestamp() + + def get_actor_name_with_rank() -> str: """ Extracts actor information from Monarch context to form a logging name. 
@@ -109,8 +166,8 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None collector.push(key, value, reduction) -def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: - """Reduce metric accumulators states to a single value per metric. +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: + """Reduce metric accumulators states to a list of metrics. Can be used when reducing metrics across ranks or services, as merging states is more precise than merging locally reduced metrics. @@ -120,7 +177,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. Returns: - Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + List[Metric]: List of reduced metrics Example: states = [ @@ -128,18 +185,18 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, ] reduce_metrics_states(states) - >>> {"loss": 2.0} + >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] Raises: ValueError: on mismatched reduction types for the same metric key. 
""" if not states: - return {} + return [] # Collect unique keys across all all_keys = set(k for state in states for k in state) - reduced_metrics = {} + reduced_metrics = [] for key in all_keys: metric_states = [state.get(key) for state in states if key in state] if not metric_states: @@ -158,7 +215,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, metric_accumulator = Reduce(first_reduction_type).accumulator_class reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) - reduced_metrics[key] = reduced_value + + # Create Metric object with reduced value + metric = Metric( + key=key, + value=reduced_value, + reduction=Reduce(first_reduction_type), + ) + reduced_metrics.append(metric) return reduced_metrics @@ -459,12 +523,12 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: self.accumulators[key].append(value) async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. Args: - step (int): Step used by backends to align metrics on the same x-axis + global_step (int): step used by backends to align metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. 
Returns: @@ -487,7 +551,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" ) return {} @@ -499,14 +563,12 @@ async def flush( # Reduce metrics from states for logging if any per-rank backend if self.logger_backends: - metrics = {} - for key, state in states.items(): - acc_class = Reduce(state["reduction_type"]).accumulator_class - metrics[key] = acc_class.get_reduced_value_from_states([state]) + # Use reduce_metrics_states for consistency + reduced_metrics = reduce_metrics_states([states]) # Log to local logger_backends for logger_backend in self.logger_backends: - await logger_backend.log(metrics, step) + await logger_backend.log(reduced_metrics, global_step) return states if return_state else {} @@ -554,7 +616,7 @@ async def init( primary_logger_metadata = {} pass - async def log(self, metrics: Dict[str, Any], step: int) -> None: + async def log(self, metrics: List[Metric], global_step: int) -> None: pass async def finish(self) -> None: @@ -582,11 +644,14 @@ async def init( else "GLOBAL" ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: - logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") - for key, value in sorted(metrics.items()): - logger.info(f" {key}: {value}") - logger.info("==============================\n") + async def log(self, metrics: List[Metric], global_step: int) -> None: + metrics_str = "\n".join( + f" {metric.key}: {metric.value}" + for metric in sorted(metrics, key=lambda m: m.key) + ) + logger.info( + f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" + ) async def finish(self) -> None: pass @@ -701,11 +766,17 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): settings=settings, ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: + 
async def log(self, metrics: List[Metric], global_step: int) -> None: if self.run: - log_data = {**metrics, "global_step": step} + # Convert metrics to WandB log format + log_data = {"global_step": global_step} + for metric in metrics: + log_data[metric.key] = metric.value + self.run.log(log_data) - logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + logger.info( + f"WandbBackend: Logged {len(metrics)} metrics at global_step {global_step}" + ) else: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index e85b81e26..47577d916 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -15,7 +15,7 @@ import torch -from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import record_metric, Reduce # Thread-local memory tracking state @@ -125,7 +125,7 @@ def start(self) -> None: # Start timing (always enabled) time_with_gpu_events = ( - os.getenv(METRIC_TIMER_USES_CUDA, str(self.time_with_gpu)).lower() == "true" + os.getenv(METRIC_TIMER_USES_GPU, str(self.time_with_gpu)).lower() == "true" ) and torch.cuda.is_available() self._timer = _TimerCUDA() if time_with_gpu_events else _TimerCPU() self._timer.start() diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 3e864bdf7..e65eb4f42 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -4,24 +4,224 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-"""Unit tests for core metrics functionality focusing on critical fixes in Diff 1.""" +"""Unit tests for core metrics functionality.""" +import time from unittest.mock import MagicMock, patch import pytest from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import ( + BackendRole, ConsoleBackend, get_logger_backend_class, + LoggingMode, + MaxAccumulator, MeanAccumulator, + Metric, MetricCollector, + MinAccumulator, record_metric, Reduce, + reduce_metrics_states, + StdAccumulator, + SumAccumulator, WandbBackend, ) +class TestMetricCreation: + """Test Metric object creation and record_metric function - Diff 2 features.""" + + def test_metric_creation_automatic_timestamp(self, mock_rank): + """Test Metric object creation with automatic timestamp.""" + before_time = time.time() + metric = Metric("test_key", 42.0, Reduce.MEAN) + after_time = time.time() + + assert metric.key == "test_key" + assert metric.value == 42.0 + assert metric.reduction == Reduce.MEAN + assert metric.timestamp is not None + assert before_time <= metric.timestamp <= after_time + + def test_metric_creation_custom_timestamp(self, mock_rank): + """Test Metric object creation with custom timestamp.""" + custom_time = 1234567890.0 + metric = Metric("test_key2", 24.0, Reduce.SUM, timestamp=custom_time) + assert metric.timestamp == custom_time + + def test_record_metric(self, mock_rank): + """Test record_metric calls collector correctly.""" + # Mock the MetricCollector constructor to return a mock instance + mock_collector = MagicMock() + + with patch( + "forge.observability.metrics.MetricCollector", return_value=mock_collector + ): + record_metric("loss", 1.5, Reduce.MEAN) + + # Verify push was called on the mock collector with correct parameters + mock_collector.push.assert_called_once_with("loss", 1.5, Reduce.MEAN) + + def test_new_enums_and_constants(self): + """Test new LoggingMode enum and BackendRole constants.""" + # Test LoggingMode enum 
values + assert LoggingMode.GLOBAL_REDUCE.value == "global_reduce" + assert LoggingMode.PER_RANK_REDUCE.value == "per_rank_reduce" + assert LoggingMode.PER_RANK_NO_REDUCE.value == "per_rank_no_reduce" + + # Test BackendRole constants + assert BackendRole.LOCAL == "local" + assert BackendRole.GLOBAL == "global" + + +class TestReduceOperations: + """Test reduce_metrics_states function returning List[Metric] - Diff 2 feature.""" + + def test_empty_states(self): + """Test reduce_metrics_states with empty input.""" + result = reduce_metrics_states([]) + assert result == [] + + def test_single_state(self): + """Test reduce_metrics_states with single state.""" + states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}] + result = reduce_metrics_states(states) + assert len(result) == 1 + assert result[0].key == "loss" + assert result[0].value == 5.0 + assert result[0].reduction == Reduce.MEAN + + def test_multiple_states(self): + """Test reduce_metrics_states with multiple states.""" + states = [ + {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, + {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}}, + {"accuracy": {"reduction_type": "sum", "total": 15.0}}, + ] + result = reduce_metrics_states(states) + + # Convert to dict for easier testing + result_dict = {metric.key: metric.value for metric in result} + assert result_dict["loss"] == 30.0 / 5.0 # 6.0 + assert result_dict["accuracy"] == 15.0 + + # Also check reduction types + for metric in result: + if metric.key == "loss": + assert metric.reduction == Reduce.MEAN + elif metric.key == "accuracy": + assert metric.reduction == Reduce.SUM + + def test_mismatched_reduction_types_raises_error(self): + """Test reduce_metrics_states raises error for mismatched reduction types.""" + states = [ + {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, + {"loss": {"reduction_type": "sum", "total": 20.0}}, + ] + with pytest.raises(ValueError, match="Mismatched reduction types"): + 
reduce_metrics_states(states) + + +class TestAccumulators: + """Test all accumulator classes and their operations - Diff 2 extensions.""" + + def test_sum_accumulator(self): + """Test SumAccumulator operations.""" + acc = SumAccumulator(Reduce.SUM) + + acc.append(5.0) + acc.append(3.0) + assert acc.get_value() == 8.0 + + state = acc.get_state() + assert state["total"] == 8.0 + assert state["reduction_type"] == "sum" + + acc.reset() + assert acc.get_value() == 0.0 + + def test_max_accumulator(self): + """Test MaxAccumulator operations.""" + acc = MaxAccumulator(Reduce.MAX) + + acc.append(5.0) + acc.append(10.0) + acc.append(3.0) + assert acc.get_value() == 10.0 + + state = acc.get_state() + assert state["max_val"] == 10.0 + assert state["reduction_type"] == "max" + + def test_min_accumulator(self): + """Test MinAccumulator operations.""" + acc = MinAccumulator(Reduce.MIN) + + acc.append(5.0) + acc.append(10.0) + acc.append(3.0) + assert acc.get_value() == 3.0 + + state = acc.get_state() + assert state["min_val"] == 3.0 + assert state["reduction_type"] == "min" + + def test_std_accumulator(self): + """Test StdAccumulator operations.""" + acc = StdAccumulator(Reduce.STD) + + # Test with zero/one values + assert acc.get_value() == 0.0 + acc.append(5.0) + assert acc.get_value() == 0.0 # std of single value is 0 + + # Test with multiple values + acc.append(7.0) # values: 5, 7, mean=6, std=1 + assert abs(acc.get_value() - 1.0) < 0.001 + + state = acc.get_state() + assert state["sum"] == 12.0 + assert state["sum_sq"] == 74.0 # 5^2 + 7^2 = 25 + 49 = 74 + assert state["count"] == 2 + + @pytest.mark.parametrize( + "accumulator_class,states,expected", + [ + ( + MeanAccumulator, + [ + {"reduction_type": "mean", "sum": 10.0, "count": 2}, + {"reduction_type": "mean", "sum": 20.0, "count": 3}, + ], + 6.0, # (10+20) / (2+3) + ), + ( + SumAccumulator, + [ + {"reduction_type": "sum", "total": 10.0}, + {"reduction_type": "sum", "total": 15.0}, + ], + 25.0, + ), + ], + ) + def 
test_accumulator_state_reduction(self, accumulator_class, states, expected): + """Test cross-accumulator state reduction.""" + result = accumulator_class.get_reduced_value_from_states(states) + assert result == expected + + def test_reduce_enum_accumulator_mapping(self): + """Test that Reduce enum correctly maps to accumulator classes.""" + assert Reduce.MEAN.accumulator_class == MeanAccumulator + assert Reduce.SUM.accumulator_class == SumAccumulator + assert Reduce.MAX.accumulator_class == MaxAccumulator + assert Reduce.MIN.accumulator_class == MinAccumulator + assert Reduce.STD.accumulator_class == StdAccumulator + + class TestCriticalFixes: """Test critical production fixes from Diff 1.""" @@ -41,7 +241,7 @@ async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog): collector = MetricCollector() # Should not raise error, just log warning and return empty dict - result = await collector.flush(step=1, return_state=True) + result = await collector.flush(global_step=1, return_state=True) assert result == {} assert any( "Cannot flush collected metrics" in record.message @@ -98,7 +298,9 @@ async def test_console_backend(self, mock_actor_name): await backend.init(role="local") # Test log - should not raise - await backend.log({"test": 1.0}, step=1) + # Create a test metric + test_metric = Metric("test", 1.0, Reduce.MEAN) + await backend.log([test_metric], global_step=1) await backend.finish() # Should not raise diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py index 6af7331f1..01d1603d1 100644 --- a/tests/unit_tests/observability/test_perf_tracker.py +++ b/tests/unit_tests/observability/test_perf_tracker.py @@ -12,7 +12,7 @@ import pytest import torch -from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import Reduce from forge.observability.perf_tracker 
import _TimerCPU, _TimerCUDA, trace, Tracer @@ -135,7 +135,7 @@ def test_comprehensive_workflow( if timer == "gpu" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - monkeypatch.setenv(METRIC_TIMER_USES_CUDA, str(timer == "gpu")) + monkeypatch.setenv(METRIC_TIMER_USES_GPU, str(timer == "gpu")) async def run_concurrent_tasks(): start_time = time.perf_counter() @@ -370,17 +370,17 @@ async def disabled_workflow(): ("false", _TimerCPU), ], ) - def test_metric_timer_uses_cuda_override( + def test_metric_timer_uses_gpu_override( self, env_value, expected_backend, monkeypatch ): - """Test METRIC_TIMER_USES_CUDA env var overrides timer parameter.""" + """Test METRIC_TIMER_USES_GPU env var overrides timer parameter.""" if env_value == "true" and not torch.cuda.is_available(): pytest.skip("CUDA not available") with patch("torch.cuda.is_available", return_value=True), patch( "forge.observability.perf_tracker.record_metric" ): - monkeypatch.setenv(METRIC_TIMER_USES_CUDA, env_value) + monkeypatch.setenv(METRIC_TIMER_USES_GPU, env_value) # Test with timer="cpu" (should be overridden by env) tracer = Tracer("env_test", timer="cpu") From 41ceaa48d076a121719807d80bfbd524ac771d8f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:06:48 -0700 Subject: [PATCH 03/18] update backend role typehints and enum --- src/forge/observability/__init__.py | 2 - src/forge/observability/metric_actors.py | 20 ++--- src/forge/observability/metrics.py | 79 ++++++------------- .../unit_tests/observability/test_metrics.py | 42 +++++++--- 4 files changed, 68 insertions(+), 75 deletions(-) diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index f37dacebd..b970e57fa 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -15,7 +15,6 @@ get_actor_name_with_rank, get_logger_backend_class, LoggerBackend, - LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -46,7 +45,6 @@ "BackendRole", # 
Enums "Reduce", - "LoggingMode", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 85dd56150..0c4d15c34 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -13,6 +13,7 @@ from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( + BackendRole, get_logger_backend_class, LoggerBackend, MetricCollector, @@ -147,14 +148,13 @@ async def init_backends( self, metadata_per_primary_backend: Dict[str, Dict[str, Any]], config: Dict[str, Any], - ): + ) -> None: """Init local (per-rank) logger backends and MetricCollector.""" collector = MetricCollector() await collector.init_backends(metadata_per_primary_backend, config) @endpoint - async def shutdown(self): - + async def shutdown(self) -> None: collector = MetricCollector() await collector.shutdown() @@ -185,7 +185,7 @@ def __init__(self): self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} @endpoint - async def init_backends(self, config: Dict[str, Any]): + async def init_backends(self, config: Dict[str, Any]) -> None: """ Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors in all registered fetchers. 
@@ -208,7 +208,7 @@ async def init_backends(self, config: Dict[str, Any]): for backend_name, backend_config in config.items(): backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role="global") + await backend.init(role=BackendRole.GLOBAL) # Extract metadata from primary logger to be shared with secondary loggers # and store it @@ -236,7 +236,9 @@ async def init_backends(self, config: Dict[str, Any]): await asyncio.gather(*tasks, return_exceptions=True) @endpoint - async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh): + async def register_fetcher( + self, fetcher: LocalFetcherActor, name: str | ProcMesh + ) -> None: """Registers a fetcher with the global actor. Each key represents a process mesh. If there are 2 processes, each with 2 replicas with N gpus, we would have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" @@ -250,7 +252,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMes ) @endpoint - async def deregister_fetcher(self, name: str | ProcMesh): + async def deregister_fetcher(self, name: str | ProcMesh) -> None: if name not in self.fetchers: logger.warning( f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." @@ -260,7 +262,7 @@ async def deregister_fetcher(self, name: str | ProcMesh): del self.fetchers[name] @endpoint - async def flush(self, global_step: int): + async def flush(self, global_step: int) -> None: """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. 
@@ -339,7 +341,7 @@ def get_fetcher_count(self) -> int: return len(self.fetchers) @endpoint - async def shutdown(self): + async def shutdown(self) -> None: # Finish per-rank logger_backends via fetchers if self.fetchers: tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index a32e6d7a1..5ba396d7f 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -18,41 +18,15 @@ logger = logging.getLogger(__name__) -class BackendRole: +class BackendRole(Enum): """Backend role constants for metric logging actors. Defines whether an actor operates as a local (per-rank) or global (controller) role in the distributed metrics collection system. """ - LOCAL: str = "local" - GLOBAL: str = "global" - - -class LoggingMode(Enum): - """Metric logging behavior for distributed training scenarios. - - Each mode serves different observability needs: - - GLOBAL_REDUCE = "global_reduce" - Best for: Metrics that are best visualized as a single value per step. - Behavior: All ranks accumulate → controller reduces → single log entry - Example use: 8 ranks training, want 1 loss value per step averaged across all - - PER_RANK_REDUCE = "per_rank_reduce" - Best for: Per-rank performance metrics, debugging individual rank behavior - Behavior: Each rank accumulates + logs its own reduced values - Example use: Monitor GPU utilization per rank, get 8 separate log entries per step - - PER_RANK_NO_REDUCE = "per_rank_no_reduce" - Best for: Real-time streaming, time-series debugging - Behavior: Raw values logged immediately on record_metric() calls - Example use: See what every rank is doing in real time. 
- """ - - GLOBAL_REDUCE = "global_reduce" - PER_RANK_REDUCE = "per_rank_reduce" - PER_RANK_NO_REDUCE = "per_rank_no_reduce" + LOCAL = "local" + GLOBAL = "global" class Reduce(Enum): @@ -235,7 +209,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metri class MetricAccumulator(ABC): """Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them.""" - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: self.reduction_type = reduction @abstractmethod @@ -266,7 +240,7 @@ def reset(self) -> None: class MeanAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.count = 0 @@ -298,7 +272,7 @@ def reset(self) -> None: class SumAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.total = 0.0 @@ -321,7 +295,7 @@ def reset(self) -> None: class MaxAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.max_val = float("-inf") @@ -344,7 +318,7 @@ def reset(self) -> None: class MinAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.min_val = float("inf") @@ -367,7 +341,7 @@ def reset(self) -> None: class StdAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.sum_sq = 0.0 @@ -453,7 +427,7 @@ def __new__(cls): ) return inst - def __init__(self): + def __init__(self) -> None: if hasattr(self, "_is_initialized"): return @@ -493,7 +467,7 @@ async def init_backends( # instantiate local backend logger_backend = 
get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role="local", primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata ) self.logger_backends.append(logger_backend) @@ -592,20 +566,20 @@ async def shutdown(self): class LoggerBackend(ABC): """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: self.logger_backend_config = logger_backend_config @abstractmethod async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). Args: - role (str): "global" (controller/primary) or "local" (per-rank/secondary). + role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. @@ -630,18 +604,18 @@ def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: class ConsoleBackend(LoggerBackend): """Simple console logging of metrics.""" - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: super().__init__(logger_backend_config) async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: self.prefix = ( get_actor_name_with_rank() if self.logger_backend_config.get("reduce_across_ranks", True) - else "GLOBAL" + else "Controller" ) async def log(self, metrics: List[Metric], global_step: int) -> None: @@ -675,7 +649,7 @@ class WandbBackend(LoggerBackend): group (str, optional): WandB group name for organizing runs. 
Defaults to "experiment_group" """ - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: super().__init__(logger_backend_config) self.project = logger_backend_config["project"] self.group = logger_backend_config.get("group", "experiment_group") @@ -688,25 +662,22 @@ def __init__(self, logger_backend_config: Dict[str, Any]): async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - if role not in ["global", "local"]: - raise ValueError( - f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." - ) - self.name = ( - get_actor_name_with_rank() if role == "local" else "global_controller" + get_actor_name_with_rank() + if role == BackendRole.LOCAL + else "global_controller" ) # Default global mode: only inits on controller if self.reduce_across_ranks: - if role != "global": + if role != BackendRole.GLOBAL: logger.debug( f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
) @@ -714,10 +685,10 @@ async def init( await self._init_global() # Per-rank modes based on share_run_id bool - elif role == "global" and self.share_run_id: + elif role == BackendRole.GLOBAL and self.share_run_id: await self._init_shared_global() - elif role == "local": + elif role == BackendRole.LOCAL: if self.share_run_id: await self._init_shared_local(primary_logger_metadata) else: diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index e65eb4f42..ee635c582 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -16,7 +16,6 @@ BackendRole, ConsoleBackend, get_logger_backend_class, - LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -66,15 +65,38 @@ def test_record_metric(self, mock_rank): mock_collector.push.assert_called_once_with("loss", 1.5, Reduce.MEAN) def test_new_enums_and_constants(self): - """Test new LoggingMode enum and BackendRole constants.""" - # Test LoggingMode enum values - assert LoggingMode.GLOBAL_REDUCE.value == "global_reduce" - assert LoggingMode.PER_RANK_REDUCE.value == "per_rank_reduce" - assert LoggingMode.PER_RANK_NO_REDUCE.value == "per_rank_no_reduce" - - # Test BackendRole constants - assert BackendRole.LOCAL == "local" - assert BackendRole.GLOBAL == "global" + """Test BackendRole constants and usage.""" + # Test BackendRole enum values + assert BackendRole.LOCAL.value == "local" + assert BackendRole.GLOBAL.value == "global" + + # Test that BackendRole is a proper Enum + assert isinstance(BackendRole.LOCAL, BackendRole) + assert isinstance(BackendRole.GLOBAL, BackendRole) + + @patch("forge.observability.metrics.get_actor_name_with_rank") + @pytest.mark.asyncio + async def test_backend_role_usage(self, mock_actor_name): + """Test that BackendRole constants are actually used instead of string literals.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + # Test ConsoleBackend + console_backend = 
ConsoleBackend({}) + await console_backend.init(role=BackendRole.LOCAL) + + # Test WandbBackend role validation without WandB initialization + wandb_backend = WandbBackend({"project": "test"}) + + # Mock all the WandB init methods to focus only on role validation + with patch.object(wandb_backend, "_init_global"), patch.object( + wandb_backend, "_init_shared_global" + ), patch.object(wandb_backend, "_init_shared_local"), patch.object( + wandb_backend, "_init_per_rank" + ): + + # Should not raise error for valid roles (type system prevents invalid values) + await wandb_backend.init(role=BackendRole.GLOBAL) + await wandb_backend.init(role=BackendRole.LOCAL) class TestReduceOperations: From 8a24e715c15bb3d18337a75fc491a8a81e605291 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:17:08 -0700 Subject: [PATCH 04/18] update where we check FORGE_DISABLE_METRICS --- src/forge/controller/provisioner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 429c5760f..cf712079b 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -20,6 +20,7 @@ from forge.controller.launcher import BaseLauncher, get_launcher +from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import ProcessConfig, ProvisionerConfig @@ -263,8 +264,8 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh # Spawn local fetcher actor on each process and register with global logger - # Can be disabled by FORGE_DISABLE_METRICS env var - _ = await get_or_create_metric_logger(procs) + if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": + _ = await get_or_create_metric_logger(procs) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): From 3f3bc51bd69316cd403c874a72b7e9824ae9f190 Mon Sep 17 00:00:00 2001 From: Felipe 
Mello Date: Wed, 8 Oct 2025 10:17:48 -0700 Subject: [PATCH 05/18] remove protected import --- src/forge/observability/metrics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4d527ec1b..64843f110 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,6 +13,8 @@ from monarch.actor import context, current_rank +from forge.util.logging import log_once + logger = logging.getLogger(__name__) @@ -437,8 +439,6 @@ async def init_backends( def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: if not self._is_initialized: - from forge.util.logging import log_once - log_once( logger, level=logging.WARNING, @@ -472,8 +472,6 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - from forge.util.logging import log_once - log_once( logger, level=logging.WARNING, From 4fe26116d9562826f0fcc4cc37bbce48c40ccb18 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:23:18 -0700 Subject: [PATCH 06/18] protect import --- src/forge/controller/provisioner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index cf712079b..5ca331f32 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -21,7 +21,6 @@ from forge.controller.launcher import BaseLauncher, get_launcher from forge.env_constants import FORGE_DISABLE_METRICS -from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import ProcessConfig, ProvisionerConfig @@ -265,6 +264,8 @@ def bootstrap(env: dict[str, str]): # Spawn local fetcher actor on each process and register with global logger if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": + from forge.observability.metric_actors import 
get_or_create_metric_logger + _ = await get_or_create_metric_logger(procs) return procs @@ -286,6 +287,10 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): async with self._lock: # Deregister local logger from global logger if hasattr(proc_mesh, "_local_fetcher"): + from forge.observability.metric_actors import ( + get_or_create_metric_logger, + ) + global_logger = await get_or_create_metric_logger(proc_mesh) await global_logger.deregister_fetcher.call_one(proc_mesh) From d81a4edd0b4677a03d5c2484bdc2923fe2fef0e1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 11:22:37 -0700 Subject: [PATCH 07/18] record_metric uses dataclass Metric --- src/forge/observability/metrics.py | 47 +++++++++++++++---- .../unit_tests/observability/test_metrics.py | 17 +++++-- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index f08d8d637..3c5386af9 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -118,8 +118,7 @@ def get_actor_name_with_rank() -> str: def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: - """ - Records a metric value for later reduction and logging. + """Thin wrapper to send metrics to per-rank local MetricCollectors. Relies on a per-rank MetricCollector singleton for ease of use, i.e. call `record_metric` anywhere in the code without moving the @@ -134,12 +133,14 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None Can be disabled globally by setting the environment variable `FORGE_DISABLE_METRICS=true`. 
""" - # Skip metrics collection if disabled for tests + # Skip metrics collection if os.getenv("FORGE_DISABLE_METRICS", "false").lower() == "true": return + # timestamp is added automatically by the Metric class + metric = Metric(key=key, value=value, reduction=reduction) collector = MetricCollector() - collector.push(key, value, reduction) + collector.push(metric) def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: @@ -475,7 +476,20 @@ async def init_backends( self._is_initialized = True - def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: + def push(self, metric: Metric) -> None: + """Process a metric according to configured logging modes. + + Args: + metric: Metric dataclass containing key, value, reduction type, and timestamp. + + Raises: + TypeError: If metric is not a Metric object. + + Example: + collector = MetricCollector() + metric = Metric("loss", 0.5, Reduce.MEAN) + collector.push(metric) + """ if not self._is_initialized: log_once( logger, @@ -491,10 +505,17 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: ) return - if key not in self.accumulators: - self.accumulators[key] = reduction.accumulator_class(reduction) + # Validate metric object + if not isinstance(metric, Metric): + raise TypeError(f"Expected {Metric} object, got {type(metric)}") - self.accumulators[key].append(value) + # Always accumulate for reduction and state return + key = metric.key + if key not in self.accumulators: + self.accumulators[key] = metric.reduction.accumulator_class( + metric.reduction + ) + self.accumulators[key].append(metric.value) async def flush( self, global_step: int, return_state: bool = False @@ -584,11 +605,17 @@ async def init( Raises: ValueError if missing metadata for shared local init. 
""" - if primary_logger_metadata is None: - primary_logger_metadata = {} pass + @abstractmethod async def log(self, metrics: List[Metric], global_step: int) -> None: + """ + Log a batch of metrics to the backend. + + Args: + metrics: List of Metric objects to log. + global_step: Step number for x-axis alignment across metrics. + """ pass async def finish(self) -> None: diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index ee635c582..563f52e6c 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -52,7 +52,7 @@ def test_metric_creation_custom_timestamp(self, mock_rank): assert metric.timestamp == custom_time def test_record_metric(self, mock_rank): - """Test record_metric calls collector correctly.""" + """Test record_metric creates correct Metric and calls collector.""" # Mock the MetricCollector constructor to return a mock instance mock_collector = MagicMock() @@ -61,8 +61,14 @@ def test_record_metric(self, mock_rank): ): record_metric("loss", 1.5, Reduce.MEAN) - # Verify push was called on the mock collector with correct parameters - mock_collector.push.assert_called_once_with("loss", 1.5, Reduce.MEAN) + # Verify push was called on the mock collector + mock_collector.push.assert_called_once() + + # Verify the metric passed to push + pushed_metric = mock_collector.push.call_args[0][0] + assert pushed_metric.key == "loss" + assert pushed_metric.value == 1.5 + assert pushed_metric.reduction == Reduce.MEAN def test_new_enums_and_constants(self): """Test BackendRole constants and usage.""" @@ -250,9 +256,10 @@ class TestCriticalFixes: def test_uninitialized_push_logs_warning(self, mock_rank, caplog): """Test MetricCollector.push() logs warning when uninitialized.""" collector = MetricCollector() + metric = Metric("test", 1.0, Reduce.MEAN) # Should not raise error, just log warning and return - collector.push("test", 1.0, Reduce.MEAN) + 
collector.push(metric) assert any( "Metric logging backends" in record.message for record in caplog.records ) @@ -317,7 +324,7 @@ async def test_console_backend(self, mock_actor_name): backend = ConsoleBackend({}) - await backend.init(role="local") + await backend.init(role=BackendRole.LOCAL) # Test log - should not raise # Create a test metric From 1e2255d2538d5129ff3d28412b265907c08d2b50 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 12:00:33 -0700 Subject: [PATCH 08/18] commit --- apps/grpo/main.py | 2 +- apps/grpo/qwen3_1_7b.yaml | 4 +- src/forge/controller/provisioner.py | 4 +- src/forge/observability/__init__.py | 6 +- src/forge/observability/metric_actors.py | 37 +++- src/forge/observability/metrics.py | 85 +++------ src/forge/observability/utils.py | 96 +++++++++++ tests/sandbox/toy_rl/toy_metrics/main.py | 2 +- tests/sandbox/vllm/main.py | 2 +- tests/unit_tests/observability/conftest.py | 7 +- .../observability/test_metric_actors.py | 162 ++++++++++++++++++ .../unit_tests/observability/test_metrics.py | 14 +- 12 files changed, 328 insertions(+), 93 deletions(-) create mode 100644 src/forge/observability/utils.py create mode 100644 tests/unit_tests/observability/test_metric_actors.py diff --git a/apps/grpo/main.py b/apps/grpo/main.py index c64f00bc2..770c7b9ac 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -319,7 +319,7 @@ async def main(cfg: DictConfig): ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) await ts.initialize(strategy=ts.ControllerStorageVolumes()) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 53eec5cfb..0e87cc6cf 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -18,9 +18,9 @@ 
metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + reduce_across_ranks: False console: - reduce_across_ranks: True + reduce_across_ranks: False # Dataset configuration dataset: diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 5ca331f32..755255071 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -265,8 +265,10 @@ def bootstrap(env: dict[str, str]): # Spawn local fetcher actor on each process and register with global logger if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": from forge.observability.metric_actors import get_or_create_metric_logger + from forge.observability.utils import detect_actor_name_from_call_stack - _ = await get_or_create_metric_logger(procs) + process_name = detect_actor_name_from_call_stack() + _ = await get_or_create_metric_logger(procs, process_name=process_name) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index b970e57fa..4a55ee87e 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,7 +12,6 @@ from .metrics import ( BackendRole, ConsoleBackend, - get_actor_name_with_rank, get_logger_backend_class, LoggerBackend, MaxAccumulator, @@ -29,12 +28,12 @@ WandbBackend, ) from .perf_tracker import trace, Tracer +from .utils import detect_actor_name_from_call_stack, get_actor_name_with_rank __all__ = [ # Main API functions "record_metric", "reduce_metrics_states", - "get_actor_name_with_rank", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking @@ -45,6 +44,9 @@ "BackendRole", # Enums "Reduce", + # Utility functions + "detect_actor_name_from_call_stack", + "get_actor_name_with_rank", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py 
b/src/forge/observability/metric_actors.py index 0c4d15c34..e9ba8a21d 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -19,6 +19,7 @@ MetricCollector, reduce_metrics_states, ) +from forge.observability.utils import detect_actor_name_from_call_stack logger = logging.getLogger(__name__) @@ -27,6 +28,7 @@ async def get_or_create_metric_logger( proc_mesh: ProcMesh | None = None, + process_name: str | None = None, ) -> "GlobalLoggingActor": """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), if not already initialized, registers it with the GlobalLoggingActor and returns the @@ -40,6 +42,8 @@ async def get_or_create_metric_logger( Args: proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `monarch.actor.this_proc()`. + process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging. + If None, will auto-detect from call stack or default to "UnknownActor" if not found. Returns: GlobalLoggingActor: The global logging controller. @@ -53,7 +57,7 @@ async def get_or_create_metric_logger( from forge.observability.metrics import record_metric # Main process setup - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") # Initialize logging backends await mlogger.init_backends({ @@ -66,13 +70,17 @@ async def get_or_create_metric_logger( # Training loop for step in range(max_steps): - record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) + record_metric("loss", 1.2, reduction_type=Reduce.MEAN) # ... training code with record_metric() calls ... 
await mlogger.flush(step) # Log metrics for this step # Shutdown await mlogger.shutdown() """ + + if process_name is None: + process_name = detect_actor_name_from_call_stack() + # Get or create the singleton global logger global _global_logger if _global_logger is None: @@ -104,7 +112,7 @@ async def get_or_create_metric_logger( and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true" ): local_fetcher_actor = proc.spawn( - "local_fetcher_actor", LocalFetcherActor, global_logger + "local_fetcher_actor", LocalFetcherActor, global_logger, process_name ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) proc._local_fetcher = local_fetcher_actor # pyre-ignore @@ -120,8 +128,13 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ - def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None: + def __init__( + self, + global_logger: Optional["GlobalLoggingActor"] = None, + process_name: str | None = None, + ) -> None: self.global_logger = global_logger + self.process_name = process_name # Passed to MetricCollector for logging _is_initialized = False @endpoint @@ -148,10 +161,22 @@ async def init_backends( self, metadata_per_primary_backend: Dict[str, Dict[str, Any]], config: Dict[str, Any], + global_step: int = 0, ) -> None: - """Init local (per-rank) logger backends and MetricCollector.""" + """Init local (per-rank) logger backends and MetricCollector. + + Args: + metadata_per_primary_backend (Dict[str, Dict[str, Any]]): Metadata from primary backends for shared state. + config (Dict[str, Any]): Backend configurations with logging modes and settings. + global_step (int): Initial step for metrics. 
+ """ collector = MetricCollector() - await collector.init_backends(metadata_per_primary_backend, config) + await collector.init_backends( + metadata_per_primary_backend, + config, + global_step, + process_name=self.process_name, + ) @endpoint async def shutdown(self) -> None: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3c5386af9..45b0af8dc 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,8 +13,9 @@ from typing import Any, Dict, List, Optional import pytz -from monarch.actor import context, current_rank +from monarch.actor import current_rank +from forge.observability.utils import get_actor_name_with_rank from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -68,55 +69,6 @@ def __post_init__(self): self.timestamp = datetime.now(pytz.UTC).timestamp() -def get_actor_name_with_rank() -> str: - """ - Extracts actor information from Monarch context to form a logging name. - - Returns: - str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). - Falls back to "UnknownActor" if context unavailable. 
- """ - # Add more defensive checks - ctx = context() - if ctx is None or ctx.actor_instance is None: - logger.warning("Context unavailable, using fallback actor name for logging.") - return "UnknownActor" - - actor_instance = ctx.actor_instance - rank = current_rank() - - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - rank_name = "UnknownActor" # fallback - if len(parts) >= 2: - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Extract world ID and proc rank - world_id = world_part.split("[")[0] if "[" in world_part else world_part - - # Extract clean actor name (remove "Configured" suffix if present) - if "[" in actor_part: - actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" - if actor_name.endswith("Configured"): - actor_name = actor_name[:-10] # Remove "Configured" - else: - actor_name = actor_part - - # Use last 4 characters of world_id as replica identifier - # This is deterministic, readable, and works for any number of replicas - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - # Use current_rank().rank as the local rank within the replica - local_rank = rank.rank - - rank_name = f"{actor_name}_{replica_id}_r{local_rank}" - - return rank_name - - def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: """Thin wrapper to send metrics to per-rank local MetricCollectors. @@ -443,6 +395,8 @@ async def init_backends( self, metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]], config: Dict[str, Any], + global_step: int = 0, + process_name: str | None = None, ) -> None: """A logger is represented by a backend, i.e. wandb backend. 
If reduce_across_ranks=False, the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated @@ -452,10 +406,13 @@ async def init_backends( metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary logger backend, e.g., {"wandb": {"run_id": "abc123"}}. config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + global_step (int, default 0): Initial step for metrics. + process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return + self.global_step = global_step # instantiate local backends if any for backend_name, backend_config in config.items(): @@ -470,7 +427,9 @@ async def init_backends( # instantiate local backend logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, + primary_logger_metadata=primary_metadata, + process_name=process_name, ) self.logger_backends.append(logger_backend) @@ -498,7 +457,7 @@ def push(self, metric: Metric) -> None: "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." 
" To disable this warning, please call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "or set env variable `FORGE_DISABLE_METRICS=True`" ), @@ -544,7 +503,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" + f"Collector rank {self.rank}: No metrics to flush for global_step {global_step}" ) return {} @@ -569,7 +528,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" + f"Collector for rank {self.rank} not initialized. Skipping shutdown" ) return @@ -593,6 +552,7 @@ async def init( self, role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, + process_name: Optional[str] = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -602,6 +562,7 @@ async def init( Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. + process_name (str | None): Process name for logging. Raises: ValueError if missing metadata for shared local init. 
""" @@ -636,12 +597,9 @@ async def init( self, role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, + process_name: Optional[str] = None, ) -> None: - self.prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("reduce_across_ranks", True) - else "Controller" - ) + pass async def log(self, metrics: List[Metric], global_step: int) -> None: metrics_str = "\n".join( @@ -649,7 +607,7 @@ async def log(self, metrics: List[Metric], global_step: int) -> None: for metric in sorted(metrics, key=lambda m: m.key) ) logger.info( - f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" + f"=== [METRICS STEP {global_step}] ===\n{metrics_str}\n==============================\n" ) async def finish(self) -> None: @@ -689,16 +647,13 @@ async def init( self, role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, + process_name: Optional[str] = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = ( - get_actor_name_with_rank() - if role == BackendRole.LOCAL - else "global_controller" - ) + self.name = get_actor_name_with_rank(process_name) # Default global mode: only inits on controller if self.reduce_across_ranks: diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py new file mode 100644 index 000000000..f9fc18014 --- /dev/null +++ b/src/forge/observability/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from monarch.actor import context, current_rank + +logger = logging.getLogger(__name__) + + +def detect_actor_name_from_call_stack() -> str: + """Detect ForgeActor subclass name from call stack. 
+ + Returns: + str: Actor name, defaulting to "UnknownActor" if not found. + """ + try: + import inspect + + frame = inspect.currentframe() + frame_count = 0 + + while frame: + frame = frame.f_back + if not frame: + break + + frame_count += 1 + if frame_count > 20: # Prevent infinite loops + break + + # Check for 'self' (instance method calls) + if "self" in frame.f_locals: + obj = frame.f_locals["self"] + if hasattr(obj, "__class__") and hasattr(obj.__class__, "__mro__"): + for base in obj.__class__.__mro__: + if base.__name__ == "ForgeActor": + return obj.__class__.__name__ + + # Check for 'cls' (class method calls) + if "cls" in frame.f_locals: + cls = frame.f_locals["cls"] + if hasattr(cls, "__mro__"): + for base in cls.__mro__: + if base.__name__ == "ForgeActor": + return cls.__name__ + + except Exception as e: + logger.debug(f"Call stack detection failed: {e}") + + return "UnknownActor" + + +def get_actor_name_with_rank(actor_name: Optional[str] = None) -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Args: + actor_name: Optional actor name to use. If None, will auto-detect from call stack. + + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. 
+ """ + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + if len(parts) < 2: + return "UnknownActor" + + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Use provided actor name or auto-detect from call stack + if actor_name: + final_actor_name = actor_name + else: + final_actor_name = detect_actor_name_from_call_stack() + + # Use last 4 characters of world_id as replica identifier + world_id = world_part.split("[")[0] if "[" in world_part else world_part + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + return f"{final_actor_name}_{replica_id}_r{rank.rank}" diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index d999fb700..fb5030504 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -95,7 +95,7 @@ async def main(): } service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(config) # Spawn services first (triggers registrations via provisioner hook) diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 0f3ce662c..19b5621c1 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -33,7 +33,7 @@ async def run(cfg: DictConfig): ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await 
get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) if (prompt := cfg.get("prompt")) is None: diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py index a803c252d..8e28339f1 100644 --- a/tests/unit_tests/observability/conftest.py +++ b/tests/unit_tests/observability/conftest.py @@ -22,13 +22,14 @@ def __init__(self, logger_backend_config=None): self.finish_called = False self.metadata = {} - async def init(self, role="local", primary_logger_metadata=None): + async def init(self, role="local", primary_logger_metadata=None, process_name=None): self.init_called = True self.role = role self.primary_logger_metadata = primary_logger_metadata or {} + self.process_name = process_name - async def log(self, metrics, step): - self.logged_metrics.append((metrics, step)) + async def log(self, metrics, global_step): + self.logged_metrics.append((metrics, global_step)) async def finish(self): self.finish_called = True diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py new file mode 100644 index 000000000..71e34edb4 --- /dev/null +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Optimized unit tests for metric actors functionality.""" + +import pytest + +from forge.observability.metric_actors import ( + get_or_create_metric_logger, + GlobalLoggingActor, + LocalFetcherActor, +) +from monarch.actor import this_host + + +@pytest.fixture +def global_logger(): + """Create a GlobalLoggingActor for testing.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestGlobalLogger", GlobalLoggingActor) + + +@pytest.fixture +def local_fetcher(global_logger): + """Create a LocalFetcherActor linked to global logger.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger) + + +class TestBasicOperations: + """Test basic operations for actors.""" + + @pytest.mark.asyncio + async def test_local_fetcher_flush(self, local_fetcher): + """Test LocalFetcherActor flush operations.""" + result_with_state = await local_fetcher.flush.call_one( + global_step=1, return_state=True + ) + assert result_with_state == {} + + result_without_state = await local_fetcher.flush.call_one( + global_step=1, return_state=False + ) + assert result_without_state == {} + + @pytest.mark.asyncio + async def test_global_logger_basic_ops(self, global_logger): + """Test GlobalLoggingActor basic operations.""" + count = await global_logger.get_fetcher_count.call_one() + assert count >= 0 + + has_fetcher = await global_logger.has_fetcher.call_one("nonexistent") + assert has_fetcher is False + + # Global logger flush (should not raise error) + await global_logger.flush.call_one(global_step=1) + + @pytest.mark.asyncio + async def test_backend_init(self, local_fetcher): + """Test backend initialization and shutdown.""" + metadata = {"wandb": {"shared_run_id": "test123"}} + config = {"console": {"logging_mode": "per_rank_reduce"}} + + await local_fetcher.init_backends.call_one(metadata, config, global_step=5) + await local_fetcher.shutdown.call_one() + + +class TestRegistrationLifecycle: + """Test 
registration lifecycle.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_registration_lifecycle(self, global_logger, local_fetcher): + """Test complete registration/deregistration lifecycle.""" + proc_name = "lifecycle_test_proc" + + # Initial state + initial_count = await global_logger.get_fetcher_count.call_one() + assert await global_logger.has_fetcher.call_one(proc_name) is False + + # Register + await global_logger.register_fetcher.call_one(local_fetcher, proc_name) + + # Verify registered + new_count = await global_logger.get_fetcher_count.call_one() + assert new_count == initial_count + 1 + assert await global_logger.has_fetcher.call_one(proc_name) is True + + # Deregister + await global_logger.deregister_fetcher.call_one(proc_name) + + # Verify deregistered + final_count = await global_logger.get_fetcher_count.call_one() + assert final_count == initial_count + assert await global_logger.has_fetcher.call_one(proc_name) is False + + +class TestBackendConfiguration: + """Test backend configuration validation.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_valid_backend_configs(self, global_logger): + """Test valid backend configurations.""" + # Empty config + await global_logger.init_backends.call_one({}) + + # Valid configs for all logging modes + for mode in ["per_rank_reduce", "per_rank_no_reduce", "global_reduce"]: + config = {"console": {"logging_mode": mode}} + await global_logger.init_backends.call_one(config) + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_invalid_backend_configs(self, global_logger): + """Test invalid backend configurations are handled gracefully.""" + # Empty config should work + await global_logger.init_backends.call_one({}) + + # Config with only project should work + config_with_project = {"console": {"project": "test_project"}} + await global_logger.init_backends.call_one(config_with_project) + + # Config with reduce_across_ranks should work (Diff 3 doesn't 
validate logging_mode yet) + config_with_reduce = {"console": {"reduce_across_ranks": True}} + await global_logger.init_backends.call_one(config_with_reduce) + + +class TestErrorHandling: + """Test error handling scenarios.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_deregister_nonexistent_fetcher(self, global_logger): + """Test deregistering non-existent fetcher doesn't crash.""" + await global_logger.deregister_fetcher.call_one("nonexistent_proc") + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_shutdown(self, global_logger): + """Test shutdown without issues.""" + await global_logger.shutdown.call_one() + + +class TestGetOrCreateMetricLogger: + """Test the integration function.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_get_or_create_functionality(self): + """Test get_or_create_metric_logger basic functionality.""" + result = await get_or_create_metric_logger() + + # Should return a GlobalLoggingActor mesh + assert result is not None + + # Should be able to call basic methods + count = await result.get_fetcher_count.call_one() + assert count >= 0 diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 563f52e6c..9ea8ade38 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -80,12 +80,9 @@ def test_new_enums_and_constants(self): assert isinstance(BackendRole.LOCAL, BackendRole) assert isinstance(BackendRole.GLOBAL, BackendRole) - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_backend_role_usage(self, mock_actor_name): + async def test_backend_role_usage(self): """Test that BackendRole constants are actually used instead of string literals.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - # Test ConsoleBackend console_backend = ConsoleBackend({}) await console_backend.init(role=BackendRole.LOCAL) @@ -295,10 
+292,8 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): mock_collector_class.assert_called_once() mock_collector.push.assert_called_once() - @patch("forge.observability.metrics.get_actor_name_with_rank") - def test_wandb_backend_creation(self, mock_actor_name): + def test_wandb_backend_creation(self): """Test WandbBackend creation and basic setup without WandB dependency.""" - mock_actor_name.return_value = "TestActor_abcd_r0" config = { "project": "test_project", @@ -316,12 +311,9 @@ def test_wandb_backend_creation(self, mock_actor_name): metadata = backend.get_metadata_for_secondary_ranks() assert metadata == {} # Should be empty when no run - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_console_backend(self, mock_actor_name): + async def test_console_backend(self): """Test ConsoleBackend basic operations.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - backend = ConsoleBackend({}) await backend.init(role=BackendRole.LOCAL) From 5b477e8456a42c69abcf369435c87e1fb4b47e30 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 19:03:31 -0700 Subject: [PATCH 09/18] commit --- src/forge/observability/perf_tracker.py | 144 +++++++++--------- .../observability/test_perf_tracker.py | 26 +++- 2 files changed, 92 insertions(+), 78 deletions(-) diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index 47577d916..184d05c26 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -3,12 +3,12 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+ import inspect import logging import os import threading import time - from concurrent.futures import Future, ThreadPoolExecutor from functools import lru_cache, wraps from typing import List, Optional, Protocol, Tuple @@ -18,6 +18,8 @@ from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import record_metric, Reduce +logger = logging.getLogger(__name__) + # Thread-local memory tracking state _local = threading.local() @@ -44,7 +46,6 @@ def _warn_nested_memory_tracking(prefix: str) -> None: """ - class Tracer: ========== """ @@ -150,10 +151,9 @@ def stop(self) -> None: if not self._active: raise ValueError("Tracer must be started before calling stop") - # Stop timing (always enabled) - # step("end") is dropped from steps, but included in total sum - self._timer.step("end") # pyre-ignore - self._record_timing_metrics() + # Stop timing + durations, final_ms = self._timer.get_all_durations() # pyre-ignore + self._record_timing_metrics(durations, final_ms) self._timer = None # Stop memory tracking @@ -163,20 +163,17 @@ def stop(self) -> None: self._active = False def _start_memory_tracking(self) -> None: - is_outer_scope = not _is_memory_active() - should_track = ( - self.track_memory and is_outer_scope and torch.cuda.is_available() - ) + if not (self.track_memory and torch.cuda.is_available()): + return - if self.track_memory and not is_outer_scope: + if _is_memory_active(): _warn_nested_memory_tracking(self.prefix) return - if should_track: - _set_memory_active(True) - torch.cuda.reset_max_memory_allocated() - self._start_mem = torch.cuda.memory_allocated() - self._memory_started = True + _set_memory_active(True) + torch.cuda.reset_max_memory_allocated() + self._start_mem = torch.cuda.memory_allocated() + self._memory_started = True def _stop_memory_tracking(self) -> None: if not self._memory_started: @@ -193,17 +190,15 @@ def _stop_memory_tracking(self) -> None: torch.cuda.reset_max_memory_allocated() 
self._memory_started = False - def _record_timing_metrics(self) -> None: - durations = self._timer.get_all_durations() # pyre-ignore - - # Total: sum all recorded durations (full timeline including end) - total_ms = sum(d_ms for name, d_ms in durations) + def _record_timing_metrics( + self, durations: List[Tuple[str, float]], final_ms: float + ) -> None: + total_ms = sum(d_ms for _, d_ms in durations) + final_ms total_s = total_ms / 1000.0 record_metric(f"{self.prefix}/total_duration_avg_s", total_s, Reduce.MEAN) record_metric(f"{self.prefix}/total_duration_max_s", total_s, Reduce.MAX) - # Steps: record each individually (drop last "end") - for name, d_ms in durations[:-1]: + for name, d_ms in durations: d_s = d_ms / 1000.0 record_metric(f"{self.prefix}/{name}/duration_avg_s", d_s, Reduce.MEAN) record_metric(f"{self.prefix}/{name}/duration_max_s", d_s, Reduce.MAX) @@ -216,7 +211,7 @@ def start(self) -> None: def step(self, name: str) -> None: ... - def get_all_durations(self) -> List[Tuple[str, float]]: + def get_all_durations(self) -> Tuple[List[Tuple[str, float]], float]: ... @@ -242,13 +237,27 @@ def step(self, name: str) -> None: self._durations.append((name, delta_ms)) self._chain_start = now - def get_all_durations(self) -> List[Tuple[str, float]]: - return self._durations[:] + def get_all_durations(self) -> Tuple[List[Tuple[str, float]], float]: + """Retrieve list of (step_name, duration) tuples. + Also computes and returns final duration since last step (or start if none).""" + final_ms = 0.0 + if self._chain_start is not None: + now = time.perf_counter() + final_ms = (now - self._chain_start) * 1000 + return self._durations[:], final_ms class _TimerCUDA(_TimerProtocol): - """CUDA timing backend with non-blocking events and futures. - Uses a thread pool to poll CUDA events asynchronously without blocking the main thread. + """CUDA timing backend for Tracer: Chains events on current stream for precise GPU durations. 
+ Steps submit non-blocking futures; polls async in another thread. + + Example: + timer = _TimerCUDA() + timer.start() + # torch.mm(a, b) # ~100ms GPU + timer.step("matmul") + # torch.mm(c, d) # ~200ms + durs_steps, final_step = timer.get_all_durations() # ([( "matmul", 100 )], 200) """ def __init__(self, max_workers: int = 2) -> None: @@ -277,7 +286,6 @@ def step(self, name: str) -> None: Args: name: Label for this segment's duration """ - # Submit polling future; chain to next event. if self._chain_start is None: raise ValueError("Timer must be started before calling step") @@ -285,66 +293,62 @@ def step(self, name: str) -> None: end_event = torch.cuda.Event(enable_timing=True) end_event.record(stream) - def _compute_elapsed(start_event, end_event): - # Poll with backoff: starts fast (1ms), grows to cap (50ms) for mixed workloads. - sleep_time = 0.001 # Start at 1ms - while not end_event.query(): - time.sleep(sleep_time) - sleep_time = min(sleep_time * 1.5, 0.05) # Backoff, cap at 50ms - return start_event.elapsed_time(end_event) - - future = self._executor.submit(_compute_elapsed, self._chain_start, end_event) + future = self._executor.submit(self._poll_elapsed, self._chain_start, end_event) index = len(self._futures) self._futures.append((name, future, index)) - if len(self._futures) >= 5: # clean up every 5 self._collect_completed_futures() self._chain_start = end_event - def _collect_completed_futures(self) -> None: + def _poll_elapsed( + self, start_event: torch.cuda.Event, end_event: torch.cuda.Event + ) -> float: + """Compute elapsed time after polling with backoff.""" + # Poll until ready + sleep_time = 0.001 # Start at 1ms + while not end_event.query(): + time.sleep(sleep_time) + sleep_time = min(sleep_time * 1.5, 0.05) # Backoff, cap at 50ms + return start_event.elapsed_time(end_event) + + def _collect_completed_futures(self, wait_till_done: bool = False) -> None: """Drain done futures to avoid memory leak; update durations in submission order.""" - 
completed = [] still_pending = [] for name, future, idx in self._futures: - if future.done(): - try: - dur = future.result() - completed.append((idx, name, dur)) - except Exception as e: - raise RuntimeError(f"Timing failed for {name}: {e}") from e + if future.done() or wait_till_done: + dur = future.result() + self._durations.append((name, dur)) else: still_pending.append((name, future, idx)) - # Sort completed by submission index to preserve order - completed.sort(key=lambda x: x[0]) - for _, name, dur in completed: - self._durations.append((name, dur)) - self._futures = still_pending - def get_all_durations(self) -> List[Tuple[str, float]]: - """Retrieve list of (name, duration) tuples in submission order after waiting for background polls to finish.""" - # Wait and collect if pendings; return durations. - self._collect_completed_futures() - completed = [] - for name, future, idx in self._futures: - try: - dur = future.result() - completed.append((idx, name, dur)) - except Exception as e: - raise RuntimeError(f"Timing failed for {name}: {e}") from e - - # Sort by submission index to preserve order - completed.sort(key=lambda x: x[0]) - for _, name, dur in completed: - self._durations.append((name, dur)) + def get_all_durations(self) -> Tuple[List[Tuple[str, float]], float]: + """Retrieve list of (step_name, duration) tuples in random order after waiting for background polls to finish. 
+ Also computes and returns final duration since last step (or start if none).""" + # Submit final as a special step (reuses step logic; no collect needed here) + stop_step = f"_final_internal_{id(self)}" + self.step(stop_step) + # Wait on remaining futures + self._collect_completed_futures(wait_till_done=True) self._futures.clear() - return self._durations[:] + + # Extract final_ms + final_ms = 0.0 + durations = [ + (name, duration) for name, duration in self._durations if name != stop_step + ] + for name, duration in self._durations: + if name == stop_step: + final_ms = duration + break + + return durations, final_ms def __del__(self) -> None: - # Fallback cleanup in finalizer; ignores errors to avoid shutdown noise. + # Fallback cleanup in finalizer try: self._executor.shutdown(wait=True) except Exception: diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py index 01d1603d1..97d7e105d 100644 --- a/tests/unit_tests/observability/test_perf_tracker.py +++ b/tests/unit_tests/observability/test_perf_tracker.py @@ -309,29 +309,39 @@ def test_tracer_and_timer_reuse(self, mock_record_metric_calls): cpu_timer.start() time.sleep(0.005) cpu_timer.step("cpu_step1") - durations1 = cpu_timer.get_all_durations() + cpu_durations_list1, cpu_final_ms1 = cpu_timer.get_all_durations() cpu_timer.start() time.sleep(0.005) cpu_timer.step("cpu_step2") - durations2 = cpu_timer.get_all_durations() + cpu_durations_list2, cpu_final_ms2 = cpu_timer.get_all_durations() - assert len(durations1) == 1 and durations1[0][0] == "cpu_step1" - assert len(durations2) == 1 and durations2[0][0] == "cpu_step2" + assert ( + len(cpu_durations_list1) == 1 and cpu_durations_list1[0][0] == "cpu_step1" + ) + assert ( + len(cpu_durations_list2) == 1 and cpu_durations_list2[0][0] == "cpu_step2" + ) # Test CUDA timer reuse (if available) if torch.cuda.is_available(): cuda_timer = _TimerCUDA() cuda_timer.start() cuda_timer.step("cuda_step1") - 
cuda_durations1 = cuda_timer.get_all_durations() + cuda_durations_list1, cuda_final_ms1 = cuda_timer.get_all_durations() cuda_timer.start() cuda_timer.step("cuda_step2") - cuda_durations2 = cuda_timer.get_all_durations() + cuda_durations_list2, cuda_final_ms2 = cuda_timer.get_all_durations() - assert len(cuda_durations1) == 1 and cuda_durations1[0][0] == "cuda_step1" - assert len(cuda_durations2) == 1 and cuda_durations2[0][0] == "cuda_step2" + assert ( + len(cuda_durations_list1) == 1 + and cuda_durations_list1[0][0] == "cuda_step1" + ) + assert ( + len(cuda_durations_list2) == 1 + and cuda_durations_list2[0][0] == "cuda_step2" + ) def test_exception_handling_context_manager(self, mock_record_metric_calls): """Test context manager properly cleans up on exception.""" From 471b88aec2af446d81ebdb2910958392cf97d89d Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 19:36:04 -0700 Subject: [PATCH 10/18] revert --- src/forge/observability/perf_tracker.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index 8342a0168..834688436 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -163,17 +163,20 @@ def stop(self) -> None: self._active = False def _start_memory_tracking(self) -> None: - if not (self.track_memory and torch.cuda.is_available()): - return + is_outer_scope = not _is_memory_active() + should_track = ( + self.track_memory and is_outer_scope and torch.cuda.is_available() + ) - if _is_memory_active(): + if self.track_memory and not is_outer_scope: _warn_nested_memory_tracking(self.prefix) return - _set_memory_active(True) - torch.cuda.reset_max_memory_allocated() - self._start_mem = torch.cuda.memory_allocated() - self._memory_started = True + if should_track: + _set_memory_active(True) + torch.cuda.reset_max_memory_allocated() + self._start_mem = torch.cuda.memory_allocated() + 
self._memory_started = True def _stop_memory_tracking(self) -> None: if not self._memory_started: From fa4895f5e95b7bf53c1ababd6c8e6975eddea658 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 20:22:02 -0700 Subject: [PATCH 11/18] remove unnecessary code --- src/forge/controller/provisioner.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 21540d17e..258849429 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -264,10 +264,8 @@ def bootstrap(env: dict[str, str]): # Spawn local fetcher actor on each process and register with global logger if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": from forge.observability.metric_actors import get_or_create_metric_logger - from forge.observability.utils import detect_actor_name_from_call_stack - process_name = detect_actor_name_from_call_stack() - _ = await get_or_create_metric_logger(procs, process_name=process_name) + _ = await get_or_create_metric_logger(procs) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): From 7bb1fe7a13e24d3b1f863455b5ba87c2e26611ad Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 20:24:33 -0700 Subject: [PATCH 12/18] better logging --- src/forge/observability/metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index cc308432c..8be29e6d2 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -536,7 +536,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for rank {self.rank} not initialized. Skipping shutdown" + f"Collector for rank {get_actor_name_with_rank()} not initialized. 
Skipping shutdown" ) return From 43d5d27c9702b7259fa0a71b03444cf410b25b15 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 9 Oct 2025 07:23:11 -0700 Subject: [PATCH 13/18] docs/names --- src/forge/observability/perf_tracker.py | 41 +++++++++++++------------ 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index 834688436..12c5e904a 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -152,8 +152,8 @@ def stop(self) -> None: raise ValueError("Tracer must be started before calling stop") # Stop timing - durations, final_ms = self._timer.get_all_durations() # pyre-ignore - self._record_timing_metrics(durations, final_ms) + durations, stop_step_ms = self._timer.get_all_durations() # pyre-ignore + self._record_timing_metrics(durations, stop_step_ms) self._timer = None # Stop memory tracking @@ -194,9 +194,9 @@ def _stop_memory_tracking(self) -> None: self._memory_started = False def _record_timing_metrics( - self, durations: list[tuple[str, float]], final_ms: float + self, durations: list[tuple[str, float]], stop_step_ms: float ) -> None: - total_ms = sum(d_ms for _, d_ms in durations) + final_ms + total_ms = sum(d_ms for _, d_ms in durations) + stop_step_ms total_s = total_ms / 1000.0 record_metric(f"{self.prefix}/total_duration_avg_s", total_s, Reduce.MEAN) record_metric(f"{self.prefix}/total_duration_max_s", total_s, Reduce.MAX) @@ -241,18 +241,18 @@ def step(self, name: str) -> None: self._chain_start = now def get_all_durations(self) -> tuple[list[tuple[str, float]], float]: - """Retrieve list of (step_name, duration) tuples. 
- Also computes and returns final duration since last step (or start if none).""" - final_ms = 0.0 + """Retrieve list of (step_name, duration) tuples and last step duration + between tracer.stop and the last step (or start if none).""" + stop_step_ms = 0.0 if self._chain_start is not None: now = time.perf_counter() - final_ms = (now - self._chain_start) * 1000 - return self._durations[:], final_ms + stop_step_ms = (now - self._chain_start) * 1000 + return self._durations[:], stop_step_ms class _TimerCUDA(_TimerProtocol): - """CUDA timing backend for Tracer: Chains events on current stream for precise GPU durations. - Steps submit non-blocking futures; polls async in another thread. + """CUDA timing backend with non-blocking events and futures. + Uses a thread pool to poll CUDA events asynchronously without blocking the main thread. Example: timer = _TimerCUDA() @@ -260,7 +260,7 @@ class _TimerCUDA(_TimerProtocol): # torch.mm(a, b) # ~100ms GPU timer.step("matmul") # torch.mm(c, d) # ~200ms - durs_steps, final_step = timer.get_all_durations() # ([( "matmul", 100 )], 200) + durs_steps, stop_step_ms = timer.get_all_durations() # ([( "matmul", 100 )], 200) """ def __init__(self, max_workers: int = 2) -> None: @@ -328,27 +328,28 @@ def _collect_completed_futures(self, wait_till_done: bool = False) -> None: self._futures = still_pending def get_all_durations(self) -> tuple[list[tuple[str, float]], float]: - """Retrieve list of (step_name, duration) tuples in random order after waiting for background polls to finish. - Also computes and returns final duration since last step (or start if none).""" - # Submit final as a special step (reuses step logic; no collect needed here) - stop_step = f"_final_internal_{id(self)}" + """Retrieve list of (step_name, duration) tuples and last step duration + between tracer.stop and the last step (or start if none). Order of tuples is random. 
+ """ + # Final timing since last step (or start) until this function is called + stop_step = f"_stop_step_{id(self)}" self.step(stop_step) # Wait on remaining futures self._collect_completed_futures(wait_till_done=True) self._futures.clear() - # Extract final_ms - final_ms = 0.0 + # Extract stop_step_ms + stop_step_ms = 0.0 durations = [ (name, duration) for name, duration in self._durations if name != stop_step ] for name, duration in self._durations: if name == stop_step: - final_ms = duration + stop_step_ms = duration break - return durations, final_ms + return durations, stop_step_ms def __del__(self) -> None: # Fallback cleanup in finalizer From 1186aec158129da573e521352d247d5d31af4479 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Thu, 9 Oct 2025 14:14:46 -0700 Subject: [PATCH 14/18] update cfg back to true --- apps/grpo/qwen3_1_7b.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 0e87cc6cf..53eec5cfb 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -18,9 +18,9 @@ metric_logging: wandb: project: "grpo-training" group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: False + reduce_across_ranks: True console: - reduce_across_ranks: False + reduce_across_ranks: True # Dataset configuration dataset: From 370c4e43bf3fad78e41a7333b34f145fbb1aaf7c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Mon, 13 Oct 2025 18:40:12 -0700 Subject: [PATCH 15/18] remove callstack, get meshname in provisioner --- src/forge/controller/provisioner.py | 2 +- src/forge/observability/__init__.py | 3 +- src/forge/observability/metric_actors.py | 3 +- src/forge/observability/utils.py | 63 +++--------------------- tests/sandbox/toy_rl/toy_metrics/main.py | 8 ++- 5 files changed, 17 insertions(+), 62 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 1bb340328..34b000e89 100644 --- a/src/forge/controller/provisioner.py 
+++ b/src/forge/controller/provisioner.py @@ -316,7 +316,7 @@ def bootstrap(env: dict[str, str]): if not FORGE_DISABLE_METRICS.get_value(): from forge.observability.metric_actors import get_or_create_metric_logger - _ = await get_or_create_metric_logger(procs) + _ = await get_or_create_metric_logger(procs, process_name=mesh_name) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 425d96bcd..1b04f76c3 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -27,7 +27,7 @@ WandbBackend, ) from .perf_tracker import trace, Tracer -from .utils import detect_actor_name_from_call_stack, get_actor_name_with_rank +from .utils import get_actor_name_with_rank __all__ = [ # Main API functions @@ -44,7 +44,6 @@ # Enums "Reduce", # Utility functions - "detect_actor_name_from_call_stack", "get_actor_name_with_rank", # Actor classes "GlobalLoggingActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 795c96cc7..e5099f3bb 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -18,7 +18,6 @@ MetricCollector, reduce_metrics_states, ) -from forge.observability.utils import detect_actor_name_from_call_stack if MONARCH_HOSTMESH_V1.get_value(): from monarch._src.actor.v1.host_mesh import this_proc @@ -85,7 +84,7 @@ async def get_or_create_metric_logger( """ if process_name is None: - process_name = detect_actor_name_from_call_stack() + process_name = "UnknownActor" # Get or create the singleton global logger global _global_logger diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py index f9fc18014..7f6522cfe 100644 --- a/src/forge/observability/utils.py +++ b/src/forge/observability/utils.py @@ -12,64 +12,24 @@ logger = logging.getLogger(__name__) -def detect_actor_name_from_call_stack() -> str: - """Detect 
ForgeActor subclass name from call stack. - - Returns: - str: Actor name, defaulting to "UnknownActor" if not found. - """ - try: - import inspect - - frame = inspect.currentframe() - frame_count = 0 - - while frame: - frame = frame.f_back - if not frame: - break - - frame_count += 1 - if frame_count > 20: # Prevent infinite loops - break - - # Check for 'self' (instance method calls) - if "self" in frame.f_locals: - obj = frame.f_locals["self"] - if hasattr(obj, "__class__") and hasattr(obj.__class__, "__mro__"): - for base in obj.__class__.__mro__: - if base.__name__ == "ForgeActor": - return obj.__class__.__name__ - - # Check for 'cls' (class method calls) - if "cls" in frame.f_locals: - cls = frame.f_locals["cls"] - if hasattr(cls, "__mro__"): - for base in cls.__mro__: - if base.__name__ == "ForgeActor": - return cls.__name__ - - except Exception as e: - logger.debug(f"Call stack detection failed: {e}") - - return "UnknownActor" - - def get_actor_name_with_rank(actor_name: Optional[str] = None) -> str: """ Extracts actor information from Monarch context to form a logging name. Args: - actor_name: Optional actor name to use. If None, will auto-detect from call stack. + actor_name: Actor name to use. Defaults to "UnknownActor" if None. Returns: str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). - Falls back to "UnknownActor" if context unavailable. + Falls back to just actor name if context unavailable. 
""" + if actor_name is None: + actor_name = "UnknownActor" + ctx = context() if ctx is None or ctx.actor_instance is None: logger.warning("Context unavailable, using fallback actor name for logging.") - return "UnknownActor" + return actor_name actor_instance = ctx.actor_instance rank = current_rank() @@ -78,19 +38,12 @@ def get_actor_name_with_rank(actor_name: Optional[str] = None) -> str: # Parse the actor_id parts = actor_id_full.split(".") if len(parts) < 2: - return "UnknownActor" + return f"{actor_name}_r{rank.rank}" world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Use provided actor name or auto-detect from call stack - if actor_name: - final_actor_name = actor_name - else: - final_actor_name = detect_actor_name_from_call_stack() # Use last 4 characters of world_id as replica identifier world_id = world_part.split("[")[0] if "[" in world_part else world_part replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - return f"{final_actor_name}_{replica_id}_r{rank.rank}" + return f"{actor_name}_{replica_id}_r{rank.rank}" diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index fb5030504..29164b38f 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -99,8 +99,12 @@ async def main(): await mlogger.init_backends.call_one(config) # Spawn services first (triggers registrations via provisioner hook) - trainer = await TrainActor.options(**service_config).as_service() - generator = await GeneratorActor.options(**service_config).as_service() + trainer = await TrainActor.options( + **service_config, mesh_name="TrainActor" + ).as_service() + generator = await GeneratorActor.options( + **service_config, mesh_name="GeneratorActor" + ).as_service() for i in range(3): print(f"\n=== Global Step {i} ===") From 9e779302d73af6bb25425f13a6ac1a55e6ad6bc6 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 14 Oct 
2025 07:27:19 -0700 Subject: [PATCH 16/18] get name from proc mesh --- src/forge/controller/provisioner.py | 1 + src/forge/observability/metric_actors.py | 10 ++++++---- src/forge/observability/metrics.py | 8 +++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 34b000e89..0bf61f7f1 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -304,6 +304,7 @@ def bootstrap(env: dict[str, str]): self._host_mesh_map[mesh_name] = host_mesh procs._host = host_mesh + procs._mesh_name = mesh_name # If we created a server, track so we can tear it down later. if server_name: diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index e5099f3bb..9ad98fde5 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -48,7 +48,8 @@ async def get_or_create_metric_logger( proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `monarch.actor.this_proc()`. process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging. - If None, will auto-detect from call stack or default to "UnknownActor" if not found. + If None, will be auto-detected from the mesh_name provided during actor initialization or + a generic mesh name if one was not provided. Returns: GlobalLoggingActor: The global logging controller. 
@@ -83,9 +84,6 @@ async def get_or_create_metric_logger( await mlogger.shutdown() """ - if process_name is None: - process_name = "UnknownActor" - # Get or create the singleton global logger global _global_logger if _global_logger is None: @@ -97,6 +95,10 @@ async def get_or_create_metric_logger( # Determine process context proc = proc_mesh if proc_mesh is not None else this_proc() + # Auto-detect process_name from proc mesh if not provided + if process_name is None: + process_name = proc._mesh_name + # Check current state for consistency proc_has_local_fetcher = hasattr(proc, "_local_fetcher") global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 8be29e6d2..f54b863ff 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -396,6 +396,7 @@ def __init__(self) -> None: self.rank = current_rank().rank self.logger_backends: list[LoggerBackend] = [] self._is_initialized = False + self.process_name: str | None = None async def init_backends( self, @@ -417,9 +418,10 @@ async def init_backends( """ if self._is_initialized: logger.debug( - f"Rank {get_actor_name_with_rank()}: MetricCollector already initialized" + f"Rank {get_actor_name_with_rank(self.process_name)}: MetricCollector already initialized" ) return + self.process_name = process_name self.global_step = global_step # instantiate local backends if any @@ -511,7 +513,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" + f"Collector rank {get_actor_name_with_rank(self.process_name)}: No metrics to flush for global_step {global_step}" ) return {} @@ -536,7 +538,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for rank {get_actor_name_with_rank()} not initialized. 
Skipping shutdown" + f"Collector for rank {get_actor_name_with_rank(self.process_name)} not initialized. Skipping shutdown" ) return From 93b0cad9fa16910338556f88e228fb4153a8da7c Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Tue, 14 Oct 2025 12:55:42 -0700 Subject: [PATCH 17/18] simplify + unit tests --- src/forge/actors/policy.py | 4 +- src/forge/controller/provisioner.py | 1 - src/forge/observability/__init__.py | 4 +- src/forge/observability/metric_actors.py | 5 +- src/forge/observability/metrics.py | 12 ++-- src/forge/observability/utils.py | 63 ++++++++++--------- tests/unit_tests/observability/__init__.py | 5 ++ .../observability/test_metric_actors.py | 2 +- .../unit_tests/observability/test_metrics.py | 6 +- tests/unit_tests/observability/test_utils.py | 60 ++++++++++++++++++ 10 files changed, 117 insertions(+), 45 deletions(-) create mode 100644 tests/unit_tests/observability/__init__.py create mode 100644 tests/unit_tests/observability/test_utils.py diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 3a1b3e86e..1579fea40 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -341,8 +341,8 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: def _preprocess_add_request( self, request: EngineCoreRequest ) -> tuple[Request, int]: - """ (forge/issues/332) Will require attention when we bump vllm versions - https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419""" + """(forge/issues/332) Will require attention when we bump vllm versions + https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419""" if request.mm_hashes is not None: raise NotImplementedError("Support for mm_hash is not implemented yet.") req = Request.from_engine_core_request(request) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 0bf61f7f1..34b000e89 100644 --- 
a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -304,7 +304,6 @@ def bootstrap(env: dict[str, str]): self._host_mesh_map[mesh_name] = host_mesh procs._host = host_mesh - procs._mesh_name = mesh_name # If we created a server, track so we can tear it down later. if server_name: diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 1b04f76c3..8efd3dace 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -27,7 +27,7 @@ WandbBackend, ) from .perf_tracker import trace, Tracer -from .utils import get_actor_name_with_rank +from .utils import get_proc_name_with_rank __all__ = [ # Main API functions @@ -44,7 +44,7 @@ # Enums "Reduce", # Utility functions - "get_actor_name_with_rank", + "get_proc_name_with_rank", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 9ad98fde5..0e29f5e19 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -8,7 +8,7 @@ import logging from typing import Any, Union -from monarch.actor import Actor, endpoint, ProcMesh +from monarch.actor import Actor, context, endpoint, ProcMesh from forge.env import FORGE_DISABLE_METRICS, MONARCH_HOSTMESH_V1 from forge.observability.metrics import ( @@ -97,7 +97,8 @@ async def get_or_create_metric_logger( # Auto-detect process_name from proc mesh if not provided if process_name is None: - process_name = proc._mesh_name + ctx = context() + process_name = ctx.actor_instance.actor_id.actor_name # Check current state for consistency proc_has_local_fetcher = hasattr(proc, "_local_fetcher") diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index f54b863ff..af0c154e2 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -15,7 +15,7 @@ import pytz from monarch.actor import 
current_rank -from forge.observability.utils import get_actor_name_with_rank +from forge.observability.utils import get_proc_name_with_rank from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -418,7 +418,7 @@ async def init_backends( """ if self._is_initialized: logger.debug( - f"Rank {get_actor_name_with_rank(self.process_name)}: MetricCollector already initialized" + f"{get_proc_name_with_rank(self.process_name)}: MetricCollector already initialized" ) return self.process_name = process_name @@ -513,7 +513,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank(self.process_name)}: No metrics to flush for global_step {global_step}" + f"Collector for {get_proc_name_with_rank(self.process_name)}: No metrics to flush for global_step {global_step}" ) return {} @@ -538,7 +538,7 @@ async def shutdown(self): """Shutdown logger_backends if initialized.""" if not self._is_initialized: logger.debug( - f"Collector for rank {get_actor_name_with_rank(self.process_name)} not initialized. Skipping shutdown" + f"Collector for rank {get_proc_name_with_rank(self.process_name)} not initialized. 
Skipping shutdown" ) return @@ -609,7 +609,7 @@ async def init( primary_logger_metadata: dict[str, Any] | None = None, process_name: str | None = None, ) -> None: - self.prefix = get_actor_name_with_rank(actor_name=process_name) + self.prefix = get_proc_name_with_rank(proc_name=process_name) async def log(self, metrics: list[Metric], global_step: int) -> None: metrics_str = "\n".join( @@ -663,7 +663,7 @@ async def init( if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = get_actor_name_with_rank(actor_name=process_name) + self.name = get_proc_name_with_rank(proc_name=process_name) # Default global mode: only inits on controller if self.reduce_across_ranks: diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py index 7f6522cfe..4a45274e3 100644 --- a/src/forge/observability/utils.py +++ b/src/forge/observability/utils.py @@ -7,43 +7,48 @@ import logging from typing import Optional -from monarch.actor import context, current_rank +from monarch.actor import context logger = logging.getLogger(__name__) -def get_actor_name_with_rank(actor_name: Optional[str] = None) -> str: +def get_proc_name_with_rank(proc_name: Optional[str] = None) -> str: """ - Extracts actor information from Monarch context to form a logging name. + Returns a unique process identifier from Monarch actor context. + + Format: "ActorName_wxyz_r{rank}" where: + - ActorName: The actor class name (e.g., "TrainActor") + - wxyz: Last 4 chars of world_name (unique replica hash) + - rank: Local rank within the replica (0, 1, 2, ...) + + Note: If called from a direct process, defaults to "client_DPROC_r0". Args: - actor_name: Actor name to use. Defaults to "UnknownActor" if None. + proc_name: Optional override for actor name. If None, uses actor_id.actor_name. Returns: - str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). - Falls back to just actor name if context unavailable.
+ str: Unique identifier or fallback name if no context available. """ - if actor_name is None: - actor_name = "UnknownActor" - ctx = context() - if ctx is None or ctx.actor_instance is None: - logger.warning("Context unavailable, using fallback actor name for logging.") - return actor_name - - actor_instance = ctx.actor_instance - rank = current_rank() - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - if len(parts) < 2: - return f"{actor_name}_r{rank.rank}" - - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - - # Use last 4 characters of world_id as replica identifier - world_id = world_part.split("[")[0] if "[" in world_part else world_part - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - return f"{actor_name}_{replica_id}_r{rank.rank}" + actor_id = ctx.actor_instance.actor_id + + # Use actor_name from actor_id if not provided + if proc_name is None: + proc_name = actor_id.actor_name + + # Try to get world_name. Each replica has a unique value. + try: + world_name = actor_id.world_name + replica_id = world_name[-4:] if len(world_name) >= 4 else world_name + except BaseException: # Catches pyo3_runtime.PanicException from Rust + # Direct proc (e.g., client) - no world_name available + replica_id = "DPROC" + + # Get rank within the replica. NOT a global rank. + try: + rank = actor_id.rank + except BaseException: # Catches pyo3_runtime.PanicException from Rust + # Direct proc - no rank available + rank = 0 + + return f"{proc_name}_{replica_id}_r{rank}" diff --git a/tests/unit_tests/observability/__init__.py b/tests/unit_tests/observability/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/unit_tests/observability/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py index 71e34edb4..501e13afe 100644 --- a/tests/unit_tests/observability/test_metric_actors.py +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -152,7 +152,7 @@ class TestGetOrCreateMetricLogger: @pytest.mark.asyncio async def test_get_or_create_functionality(self): """Test get_or_create_metric_logger basic functionality.""" - result = await get_or_create_metric_logger() + result = await get_or_create_metric_logger(process_name="TestController") # Should return a GlobalLoggingActor mesh assert result is not None diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 2d51e0a5f..d0f104459 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -417,8 +417,10 @@ async def _test_fetcher_registration(self, env_var_value, should_register_fetche if hasattr(procs, "_local_fetcher"): delattr(procs, "_local_fetcher") - # Test functionality - global_logger = await get_or_create_metric_logger(proc_mesh=procs) + # Test functionality - pass explicit process_name since test bypasses provisioner + global_logger = await get_or_create_metric_logger( + proc_mesh=procs, process_name="TestProcess" + ) # Get results to check proc_has_fetcher = hasattr(procs, "_local_fetcher") diff --git a/tests/unit_tests/observability/test_utils.py b/tests/unit_tests/observability/test_utils.py new file mode 100644 index 000000000..9a4e24d0c --- /dev/null +++ b/tests/unit_tests/observability/test_utils.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Tests for observability utility functions.""" + +import pytest + +from forge.observability.utils import get_proc_name_with_rank +from monarch.actor import Actor, endpoint, this_host + + +class UtilActor(Actor): + """Actor for testing get_proc_name_with_rank in spawned context.""" + + @endpoint + async def get_name(self) -> str: + return get_proc_name_with_rank() + + @endpoint + async def get_name_with_override(self, name: str) -> str: + return get_proc_name_with_rank(proc_name=name) + + +class TestGetProcNameWithRank: + """Tests for get_proc_name_with_rank utility.""" + + def test_direct_proc(self): + """Direct proc (test process) should return client_DPROC_r0.""" + result = get_proc_name_with_rank() + assert result == "client_DPROC_r0" + + def test_direct_proc_with_override(self): + """Direct proc with override should use provided name.""" + result = get_proc_name_with_rank(proc_name="MyProcess") + assert result == "MyProcess_DPROC_r0" + + @pytest.mark.timeout(10) + @pytest.mark.asyncio + async def test_spawned_actor(self): + """Spawned actor should return ActorName_replica_rank format.""" + p = this_host().spawn_procs(per_host={"cpus": 2}) + actor = p.spawn("UtilActor", UtilActor) + + # no override + results = await actor.get_name.call() + + assert len(results) == 2 + for i, (rank_info, result) in enumerate(results): + replica_id = result.split("_")[1] + assert result == f"UtilActor_{replica_id}_r{i}" + + # override name + results = await actor.get_name_with_override.call("CustomName") + + for i, (rank_info, result) in enumerate(results): + replica_id = result.split("_")[1] + assert result == f"CustomName_{replica_id}_r{i}" From e901ad54ea586a8afd9e65f36efd512b05dcd7fb Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 15 Oct 2025 08:54:24 -0700 Subject: [PATCH 18/18] address comments --- src/forge/observability/metric_actors.py | 2 +- src/forge/observability/metrics.py | 8 +------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git 
a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index a5f5f6677..f053d6a56 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -108,7 +108,7 @@ async def get_or_create_metric_logger( # Consistency check: both should be in sync if proc_has_local_fetcher != global_logger_has_local_fetcher: raise ValueError( - f"Inconsistent logging state for proc {proc}: " + f"Inconsistent logging state for {proc=} with {process_name=}: " f"proc has _local_fetcher={proc_has_local_fetcher}, " f"but global_logger has registration={global_logger_has_local_fetcher}. " f"This indicates a bug in logging setup/teardown. " diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index af0c154e2..4996b3a7f 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -55,13 +55,7 @@ def accumulator_class(self): class Metric: """Container for metric data including key, value, reduction type, and timestamp. - Timestamp is automatically set to current EST time if not provided. - - Args: - key: str - value: Any - reduction: Reduce - timestamp: Optional[float] = None + Timestamp is automatically set to current UTC time if not provided. """ key: str