From 77488cfe630be4390593ee16d6ca0d24dc67f93f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 08:38:55 -0700 Subject: [PATCH 1/7] commit --- src/forge/controller/provisioner.py | 3 +- src/forge/env_constants.py | 1 + src/forge/observability/metric_actors.py | 9 +- src/forge/observability/metrics.py | 37 ++- tests/unit_tests/observability/conftest.py | 95 +++++++ .../unit_tests/observability/test_metrics.py | 231 ++++++++++++++++++ 6 files changed, 370 insertions(+), 6 deletions(-) create mode 100644 tests/unit_tests/observability/conftest.py create mode 100644 tests/unit_tests/observability/test_metrics.py diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index c823afb29..429c5760f 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -262,7 +262,8 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh - # Spawn local logging actor on each process and register with global logger + # Spawn local fetcher actor on each process and register with global logger + # Can be disabled by FORGE_DISABLE_METRICS env var _ = await get_or_create_metric_logger(procs) return procs diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index 3adcdfc41..a4e024d83 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -14,4 +14,5 @@ METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" # Makes forge.observability.metrics.record_metric a no-op +# and disables spawning LocalFetcherActor in get_or_create_metric_logger FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS" diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index d67a66a83..edd1f24d8 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -6,10 +6,12 @@ import asyncio import logging +import os from typing import Any, Dict, Optional from monarch.actor import Actor, endpoint, 
get_or_spawn_controller, ProcMesh, this_proc +from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( get_logger_backend_class, LoggerBackend, @@ -95,8 +97,11 @@ async def get_or_create_metric_logger( f"Both should be True (already setup) or both False (needs setup)." ) - # Setup local_fetcher_actor if needed - if not proc_has_local_fetcher: + # Setup local_fetcher_actor if needed (unless disabled by environment flag) + if ( + not proc_has_local_fetcher + and os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true" + ): local_fetcher_actor = proc.spawn( "local_fetcher_actor", LocalFetcherActor, global_logger ) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 990a301e0..4d527ec1b 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -437,7 +437,21 @@ async def init_backends( def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: if not self._is_initialized: - raise ValueError("Collector not initialized—call init first") + from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg=( + "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." + " This happens when you try to use `record_metric` before calling `init_backends`." + " To disable this warning, please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger()`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "or set env variable `FORGE_DISABLE_METRICS=True`" + ), + ) + return if key not in self.accumulators: self.accumulators[key] = reduction.accumulator_class(reduction) @@ -458,8 +472,16 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - logger.debug( - f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." 
+ from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." + "\nPlease call in your main file:\n" + "`mlogger = await get_or_create_metric_logger()`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "before calling `flush`", ) return {} @@ -662,6 +684,15 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): raise ValueError( f"Shared ID required but not provided for {self.name} backend init" ) + + # Clear any stale service tokens that might be pointing to dead processes + # In multiprocessing environments, WandB service tokens can become stale and point + # to dead service processes. This causes wandb.init() to hang indefinitely trying + # to connect to non-existent services. Clearing forces fresh service connection. + from wandb.sdk.lib.service import service_token + + service_token.clear_service_in_env() + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) self.run = wandb.init( id=shared_id, diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py new file mode 100644 index 000000000..a803c252d --- /dev/null +++ b/tests/unit_tests/observability/conftest.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Shared fixtures and mocks for observability unit tests.""" + +from unittest.mock import MagicMock, patch + +import pytest +from forge.observability.metrics import LoggerBackend, MetricCollector + + +class MockBackend(LoggerBackend): + """Mock backend for testing metrics logging without external dependencies.""" + + def __init__(self, logger_backend_config=None): + super().__init__(logger_backend_config or {}) + self.logged_metrics = [] + self.init_called = False + self.finish_called = False + self.metadata = {} + + async def init(self, role="local", primary_logger_metadata=None): + self.init_called = True + self.role = role + self.primary_logger_metadata = primary_logger_metadata or {} + + async def log(self, metrics, step): + self.logged_metrics.append((metrics, step)) + + async def finish(self): + self.finish_called = True + + def get_metadata_for_secondary_ranks(self): + return self.metadata + + +@pytest.fixture(autouse=True) +def clear_metric_collector_singletons(): + """Clear MetricCollector singletons before each test to avoid state leakage.""" + MetricCollector._instances.clear() + yield + MetricCollector._instances.clear() + + +@pytest.fixture(autouse=True) +def clean_metrics_environment(): + """Override the global mock_metrics_globally fixture to allow real metrics testing.""" + import os + + from forge.env_constants import FORGE_DISABLE_METRICS + + # Set default state for tests (metrics enabled) + if FORGE_DISABLE_METRICS in os.environ: + del os.environ[FORGE_DISABLE_METRICS] + + yield + + +@pytest.fixture +def mock_rank(): + """Mock current_rank function with configurable rank.""" + with patch("forge.observability.metrics.current_rank") as mock: + rank_obj = MagicMock() + rank_obj.rank = 0 + mock.return_value = rank_obj + yield mock + + +@pytest.fixture +def mock_actor_context(): + """Mock Monarch actor context for testing actor name generation.""" + with patch("forge.observability.metrics.context") as mock_context, patch( + 
"forge.observability.metrics.current_rank" + ) as mock_rank: + + # Setup mock context + ctx = MagicMock() + actor_instance = MagicMock() + actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]" + ctx.actor_instance = actor_instance + mock_context.return_value = ctx + + # Setup mock rank + rank_obj = MagicMock() + rank_obj.rank = 0 + mock_rank.return_value = rank_obj + + yield { + "context": mock_context, + "rank": mock_rank, + "expected_name": "TestActor_0XQr_r0", + } diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py new file mode 100644 index 000000000..3e864bdf7 --- /dev/null +++ b/tests/unit_tests/observability/test_metrics.py @@ -0,0 +1,231 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Unit tests for core metrics functionality focusing on critical fixes in Diff 1.""" + +from unittest.mock import MagicMock, patch + +import pytest + +from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.metrics import ( + ConsoleBackend, + get_logger_backend_class, + MeanAccumulator, + MetricCollector, + record_metric, + Reduce, + WandbBackend, +) + + +class TestCriticalFixes: + """Test critical production fixes from Diff 1.""" + + def test_uninitialized_push_logs_warning(self, mock_rank, caplog): + """Test MetricCollector.push() logs warning when uninitialized.""" + collector = MetricCollector() + + # Should not raise error, just log warning and return + collector.push("test", 1.0, Reduce.MEAN) + assert any( + "Metric logging backends" in record.message for record in caplog.records + ) + + @pytest.mark.asyncio + async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog): + """Test MetricCollector.flush() logs warning when uninitialized.""" + collector = MetricCollector() + + # 
Should not raise error, just log warning and return empty dict + result = await collector.flush(step=1, return_state=True) + assert result == {} + assert any( + "Cannot flush collected metrics" in record.message + for record in caplog.records + ) + + @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "true"}) + @patch("forge.observability.metrics.MetricCollector") + def test_record_metric_disabled(self, mock_collector_class): + """Test record_metric is no-op when FORGE_DISABLE_METRICS=true.""" + record_metric("loss", 1.5, Reduce.MEAN) + mock_collector_class.assert_not_called() + + @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "false"}) + @patch("forge.observability.metrics.MetricCollector") + def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): + """Test record_metric works when FORGE_DISABLE_METRICS=false.""" + mock_collector = MagicMock() + mock_collector_class.return_value = mock_collector + + record_metric("loss", 1.5, Reduce.MEAN) + mock_collector_class.assert_called_once() + mock_collector.push.assert_called_once() + + @patch("forge.observability.metrics.get_actor_name_with_rank") + def test_wandb_backend_creation(self, mock_actor_name): + """Test WandbBackend creation and basic setup without WandB dependency.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + config = { + "project": "test_project", + "group": "test_group", + "reduce_across_ranks": True, + } + backend = WandbBackend(config) + + assert backend.project == "test_project" + assert backend.group == "test_group" + assert backend.reduce_across_ranks is True + assert backend.share_run_id is False # default + + # Test metadata method + metadata = backend.get_metadata_for_secondary_ranks() + assert metadata == {} # Should be empty when no run + + @patch("forge.observability.metrics.get_actor_name_with_rank") + @pytest.mark.asyncio + async def test_console_backend(self, mock_actor_name): + """Test ConsoleBackend basic operations.""" + 
mock_actor_name.return_value = "TestActor_abcd_r0" + + backend = ConsoleBackend({}) + + await backend.init(role="local") + + # Test log - should not raise + await backend.log({"test": 1.0}, step=1) + + await backend.finish() # Should not raise + + +class TestBasicAccumulators: + """Test basic accumulator functionality.""" + + def test_mean_accumulator(self): + """Test MeanAccumulator operations.""" + acc = MeanAccumulator(Reduce.MEAN) + + # Test initial state + assert acc.get_value() == 0.0 + state = acc.get_state() + assert state["sum"] == 0.0 + assert state["count"] == 0 + + # Test append and get_value + acc.append(10.0) + acc.append(20.0) + assert acc.get_value() == 15.0 + + # Test state + state = acc.get_state() + assert state["sum"] == 30.0 + assert state["count"] == 2 + assert state["reduction_type"] == "mean" + + # Test reset + acc.reset() + assert acc.get_value() == 0.0 + assert acc.get_state()["sum"] == 0.0 + assert acc.get_state()["count"] == 0 + + def test_reduce_enum_accumulator_mapping(self): + """Test that Reduce enum correctly maps to accumulator classes.""" + assert Reduce.MEAN.accumulator_class == MeanAccumulator + + +class TestBackendFactory: + """Test backend factory function.""" + + def test_backend_factory(self): + """Test get_logger_backend_class factory function.""" + assert get_logger_backend_class("console") == ConsoleBackend + assert get_logger_backend_class("wandb") == WandbBackend + + with pytest.raises(ValueError, match="Unknown logger backend type"): + get_logger_backend_class("invalid_backend") + + +class TestMetricCollector: + """Test MetricCollector singleton behavior.""" + + def test_singleton_per_rank(self, mock_rank): + """Test MetricCollector singleton behavior per rank.""" + mock_rank.return_value.rank = 0 + collector1 = MetricCollector() + collector2 = MetricCollector() + assert collector1 is collector2 + + # Different rank should get different instance + mock_rank.return_value.rank = 1 + collector3 = MetricCollector() + 
assert collector1 is not collector3 + + +class TestMetricActorDisabling: + """Test environment flag to disable metric actors.""" + + async def _test_fetcher_registration(self, env_var_value, should_register_fetchers): + """Check if FORGE_DISABLE_METRICS=[True, False, None] correctly disables fetcher registration. + + Args: + env_var_value: Value to set for FORGE_DISABLE_METRICS (None means unset) + should_register_fetchers: Whether fetchers should be registered (True) or not (False) + """ + import os + + import forge.observability.metric_actors + from forge.env_constants import FORGE_DISABLE_METRICS + from monarch.actor import this_host + + # set fresh env + # Note: Environment variable setup is handled by clean_metrics_environment fixture + forge.observability.metric_actors._global_logger = None + + if env_var_value is not None: + os.environ[FORGE_DISABLE_METRICS] = env_var_value + + procs = this_host().spawn_procs(per_host={"cpus": 1}) + + if hasattr(procs, "_local_fetcher"): + delattr(procs, "_local_fetcher") + + # Test functionality + global_logger = await get_or_create_metric_logger(proc_mesh=procs) + + # Get results to check + proc_has_fetcher = hasattr(procs, "_local_fetcher") + global_has_fetcher = await global_logger.has_fetcher.call_one(procs) + + # Assert based on expected behavior + if should_register_fetchers: + assert ( + proc_has_fetcher + ), f"Expected process to have _local_fetcher when {env_var_value=}" + assert ( + global_has_fetcher + ), f"Expected global logger to have fetcher registered when {env_var_value=}" + else: + assert ( + not proc_has_fetcher + ), f"Expected process to NOT have _local_fetcher when {env_var_value=}" + assert ( + not global_has_fetcher + ), f"Expected global logger to NOT have fetcher registered when {env_var_value=}" + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "env_value,should_register", + [ + ("false", True), + ("true", False), + (None, True), + ], + ) + async def 
test_fetcher_registration_with_env_flag(self, env_value, should_register): + """Test fetcher registration behavior with different environment flag values.""" + await self._test_fetcher_registration(env_value, should_register) From feb4771a5e5b36bb44cc8c0bbbfde7c3311bcf57 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 09:26:47 -0700 Subject: [PATCH 2/7] commit --- src/forge/env_constants.py | 2 +- src/forge/observability/__init__.py | 10 +- src/forge/observability/metric_actors.py | 24 +- src/forge/observability/metrics.py | 121 +++++++--- src/forge/observability/perf_tracker.py | 4 +- .../unit_tests/observability/test_metrics.py | 208 +++++++++++++++++- .../observability/test_perf_tracker.py | 10 +- 7 files changed, 329 insertions(+), 50 deletions(-) diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index a4e024d83..6e0fc30e7 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -11,7 +11,7 @@ # Force all timing methods in forge.observability.perf_tracker.py to use # CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function. 
-METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" +METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU" # Makes forge.observability.metrics.record_metric a no-op # and disables spawning LocalFetcherActor in get_or_create_metric_logger diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 52262eed5..f37dacebd 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -10,15 +10,15 @@ LocalFetcherActor, ) from .metrics import ( + BackendRole, ConsoleBackend, - # Utility functions get_actor_name_with_rank, get_logger_backend_class, - # Backend classes LoggerBackend, + LoggingMode, MaxAccumulator, MeanAccumulator, - # Accumulator classes + Metric, MetricAccumulator, MetricCollector, MinAccumulator, @@ -41,8 +41,12 @@ # Performance tracking "Tracer", "trace", + # Data classes + "Metric", + "BackendRole", # Enums "Reduce", + "LoggingMode", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index edd1f24d8..85dd56150 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -106,7 +106,7 @@ async def get_or_create_metric_logger( "local_fetcher_actor", LocalFetcherActor, global_logger ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) - proc._local_fetcher = local_fetcher_actor + proc._local_fetcher = local_fetcher_actor # pyre-ignore return global_logger @@ -125,13 +125,13 @@ def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None @endpoint async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. This should only ever be called by the global logger. 
Args: - step (int): train step used by backends to align all metrics on the same x-axis + global_step (int): step used by backends to align all metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: @@ -139,7 +139,7 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ collector = MetricCollector() - result = await collector.flush(step, return_state=return_state) + result = await collector.flush(global_step, return_state=return_state) return result @endpoint @@ -260,13 +260,13 @@ async def deregister_fetcher(self, name: str | ProcMesh): del self.fetchers[name] @endpoint - async def flush(self, step: int): + async def flush(self, global_step: int): """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. Args: - step (int): Global step for logging. + global_step (int): step for logging. 
""" if not self.fetchers: return @@ -285,12 +285,14 @@ async def flush(self, step: int): for backend_config in config.values() ) - logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + logger.debug( + f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers" + ) # Broadcast flush to all fetchers results = await asyncio.gather( *[ - f.flush.call(step, return_state=requires_reduce) + f.flush.call(global_step, return_state=requires_reduce) for f in self.fetchers.values() ], return_exceptions=True, @@ -314,10 +316,10 @@ async def flush(self, step: int): ) if not all_local_states: - logger.warning(f"No states to reduce for step {step}") + logger.warning(f"No states to reduce for global_step {global_step}") return - # Reduce + # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) # Log to each global logger_backend @@ -325,7 +327,7 @@ async def flush(self, step: int): logger_backend_name, logger_backend, ) in self.global_logger_backends.items(): - await logger_backend.log(reduced_metrics, step) + await logger_backend.log(reduced_metrics, global_step) @endpoint def has_fetcher(self, name: str | ProcMesh) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4d527ec1b..a32e6d7a1 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -5,17 +5,56 @@ # LICENSE file in the root directory of this source tree. import logging - import os from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional +import pytz from monarch.actor import context, current_rank logger = logging.getLogger(__name__) +class BackendRole: + """Backend role constants for metric logging actors. + + Defines whether an actor operates as a local (per-rank) or global (controller) role + in the distributed metrics collection system. 
+ """ + + LOCAL: str = "local" + GLOBAL: str = "global" + + +class LoggingMode(Enum): + """Metric logging behavior for distributed training scenarios. + + Each mode serves different observability needs: + + GLOBAL_REDUCE = "global_reduce" + Best for: Metrics that are best visualized as a single value per step. + Behavior: All ranks accumulate → controller reduces → single log entry + Example use: 8 ranks training, want 1 loss value per step averaged across all + + PER_RANK_REDUCE = "per_rank_reduce" + Best for: Per-rank performance metrics, debugging individual rank behavior + Behavior: Each rank accumulates + logs its own reduced values + Example use: Monitor GPU utilization per rank, get 8 separate log entries per step + + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + Best for: Real-time streaming, time-series debugging + Behavior: Raw values logged immediately on record_metric() calls + Example use: See what every rank is doing in real time. + """ + + GLOBAL_REDUCE = "global_reduce" + PER_RANK_REDUCE = "per_rank_reduce" + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -35,6 +74,24 @@ def accumulator_class(self): return mapping[self] +@dataclass +class Metric: + """Container for metric data including key, value, reduction type, and timestamp. + + Timestamp is automatically set to the current UTC time if not provided. + """ + + key: str + value: Any + reduction: Reduce + timestamp: Optional[float] = None + + def __post_init__(self): + if self.timestamp is None: + # Always record in UTC timezone + self.timestamp = datetime.now(pytz.UTC).timestamp() + + def get_actor_name_with_rank() -> str: """ Extracts actor information from Monarch context to form a logging name. 
@@ -109,8 +166,8 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None collector.push(key, value, reduction) -def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: - """Reduce metric accumulators states to a single value per metric. +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: + """Reduce metric accumulators states to a list of metrics. Can be used when reducing metrics across ranks or services, as merging states is more precise than merging locally reduced metrics. @@ -120,7 +177,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. Returns: - Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + List[Metric]: List of reduced metrics Example: states = [ @@ -128,18 +185,18 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, ] reduce_metrics_states(states) - >>> {"loss": 2.0} + >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] Raises: ValueError: on mismatched reduction types for the same metric key. 
""" if not states: - return {} + return [] # Collect unique keys across all all_keys = set(k for state in states for k in state) - reduced_metrics = {} + reduced_metrics = [] for key in all_keys: metric_states = [state.get(key) for state in states if key in state] if not metric_states: @@ -158,7 +215,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, metric_accumulator = Reduce(first_reduction_type).accumulator_class reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) - reduced_metrics[key] = reduced_value + + # Create Metric object with reduced value + metric = Metric( + key=key, + value=reduced_value, + reduction=Reduce(first_reduction_type), + ) + reduced_metrics.append(metric) return reduced_metrics @@ -459,12 +523,12 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: self.accumulators[key].append(value) async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. Args: - step (int): Step used by backends to align metrics on the same x-axis + global_step (int): step used by backends to align metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. 
Returns: @@ -487,7 +551,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" ) return {} @@ -499,14 +563,12 @@ async def flush( # Reduce metrics from states for logging if any per-rank backend if self.logger_backends: - metrics = {} - for key, state in states.items(): - acc_class = Reduce(state["reduction_type"]).accumulator_class - metrics[key] = acc_class.get_reduced_value_from_states([state]) + # Use reduce_metrics_states for consistency + reduced_metrics = reduce_metrics_states([states]) # Log to local logger_backends for logger_backend in self.logger_backends: - await logger_backend.log(metrics, step) + await logger_backend.log(reduced_metrics, global_step) return states if return_state else {} @@ -554,7 +616,7 @@ async def init( primary_logger_metadata = {} pass - async def log(self, metrics: Dict[str, Any], step: int) -> None: + async def log(self, metrics: List[Metric], global_step: int) -> None: pass async def finish(self) -> None: @@ -582,11 +644,14 @@ async def init( else "GLOBAL" ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: - logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") - for key, value in sorted(metrics.items()): - logger.info(f" {key}: {value}") - logger.info("==============================\n") + async def log(self, metrics: List[Metric], global_step: int) -> None: + metrics_str = "\n".join( + f" {metric.key}: {metric.value}" + for metric in sorted(metrics, key=lambda m: m.key) + ) + logger.info( + f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" + ) async def finish(self) -> None: pass @@ -701,11 +766,17 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): settings=settings, ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: + 
async def log(self, metrics: List[Metric], global_step: int) -> None: if self.run: - log_data = {**metrics, "global_step": step} + # Convert metrics to WandB log format + log_data = {"global_step": global_step} + for metric in metrics: + log_data[metric.key] = metric.value + self.run.log(log_data) - logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + logger.info( + f"WandbBackend: Logged {len(metrics)} metrics at global_step {global_step}" + ) else: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index e85b81e26..47577d916 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -15,7 +15,7 @@ import torch -from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import record_metric, Reduce # Thread-local memory tracking state @@ -125,7 +125,7 @@ def start(self) -> None: # Start timing (always enabled) time_with_gpu_events = ( - os.getenv(METRIC_TIMER_USES_CUDA, str(self.time_with_gpu)).lower() == "true" + os.getenv(METRIC_TIMER_USES_GPU, str(self.time_with_gpu)).lower() == "true" ) and torch.cuda.is_available() self._timer = _TimerCUDA() if time_with_gpu_events else _TimerCPU() self._timer.start() diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 3e864bdf7..e65eb4f42 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -4,24 +4,224 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-"""Unit tests for core metrics functionality focusing on critical fixes in Diff 1.""" +"""Unit tests for core metrics functionality.""" +import time from unittest.mock import MagicMock, patch import pytest from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import ( + BackendRole, ConsoleBackend, get_logger_backend_class, + LoggingMode, + MaxAccumulator, MeanAccumulator, + Metric, MetricCollector, + MinAccumulator, record_metric, Reduce, + reduce_metrics_states, + StdAccumulator, + SumAccumulator, WandbBackend, ) +class TestMetricCreation: + """Test Metric object creation and record_metric function - Diff 2 features.""" + + def test_metric_creation_automatic_timestamp(self, mock_rank): + """Test Metric object creation with automatic timestamp.""" + before_time = time.time() + metric = Metric("test_key", 42.0, Reduce.MEAN) + after_time = time.time() + + assert metric.key == "test_key" + assert metric.value == 42.0 + assert metric.reduction == Reduce.MEAN + assert metric.timestamp is not None + assert before_time <= metric.timestamp <= after_time + + def test_metric_creation_custom_timestamp(self, mock_rank): + """Test Metric object creation with custom timestamp.""" + custom_time = 1234567890.0 + metric = Metric("test_key2", 24.0, Reduce.SUM, timestamp=custom_time) + assert metric.timestamp == custom_time + + def test_record_metric(self, mock_rank): + """Test record_metric calls collector correctly.""" + # Mock the MetricCollector constructor to return a mock instance + mock_collector = MagicMock() + + with patch( + "forge.observability.metrics.MetricCollector", return_value=mock_collector + ): + record_metric("loss", 1.5, Reduce.MEAN) + + # Verify push was called on the mock collector with correct parameters + mock_collector.push.assert_called_once_with("loss", 1.5, Reduce.MEAN) + + def test_new_enums_and_constants(self): + """Test new LoggingMode enum and BackendRole constants.""" + # Test LoggingMode enum 
values + assert LoggingMode.GLOBAL_REDUCE.value == "global_reduce" + assert LoggingMode.PER_RANK_REDUCE.value == "per_rank_reduce" + assert LoggingMode.PER_RANK_NO_REDUCE.value == "per_rank_no_reduce" + + # Test BackendRole constants + assert BackendRole.LOCAL == "local" + assert BackendRole.GLOBAL == "global" + + +class TestReduceOperations: + """Test reduce_metrics_states function returning List[Metric] - Diff 2 feature.""" + + def test_empty_states(self): + """Test reduce_metrics_states with empty input.""" + result = reduce_metrics_states([]) + assert result == [] + + def test_single_state(self): + """Test reduce_metrics_states with single state.""" + states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}] + result = reduce_metrics_states(states) + assert len(result) == 1 + assert result[0].key == "loss" + assert result[0].value == 5.0 + assert result[0].reduction == Reduce.MEAN + + def test_multiple_states(self): + """Test reduce_metrics_states with multiple states.""" + states = [ + {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, + {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}}, + {"accuracy": {"reduction_type": "sum", "total": 15.0}}, + ] + result = reduce_metrics_states(states) + + # Convert to dict for easier testing + result_dict = {metric.key: metric.value for metric in result} + assert result_dict["loss"] == 30.0 / 5.0 # 6.0 + assert result_dict["accuracy"] == 15.0 + + # Also check reduction types + for metric in result: + if metric.key == "loss": + assert metric.reduction == Reduce.MEAN + elif metric.key == "accuracy": + assert metric.reduction == Reduce.SUM + + def test_mismatched_reduction_types_raises_error(self): + """Test reduce_metrics_states raises error for mismatched reduction types.""" + states = [ + {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, + {"loss": {"reduction_type": "sum", "total": 20.0}}, + ] + with pytest.raises(ValueError, match="Mismatched reduction types"): + 
reduce_metrics_states(states) + + +class TestAccumulators: + """Test all accumulator classes and their operations - Diff 2 extensions.""" + + def test_sum_accumulator(self): + """Test SumAccumulator operations.""" + acc = SumAccumulator(Reduce.SUM) + + acc.append(5.0) + acc.append(3.0) + assert acc.get_value() == 8.0 + + state = acc.get_state() + assert state["total"] == 8.0 + assert state["reduction_type"] == "sum" + + acc.reset() + assert acc.get_value() == 0.0 + + def test_max_accumulator(self): + """Test MaxAccumulator operations.""" + acc = MaxAccumulator(Reduce.MAX) + + acc.append(5.0) + acc.append(10.0) + acc.append(3.0) + assert acc.get_value() == 10.0 + + state = acc.get_state() + assert state["max_val"] == 10.0 + assert state["reduction_type"] == "max" + + def test_min_accumulator(self): + """Test MinAccumulator operations.""" + acc = MinAccumulator(Reduce.MIN) + + acc.append(5.0) + acc.append(10.0) + acc.append(3.0) + assert acc.get_value() == 3.0 + + state = acc.get_state() + assert state["min_val"] == 3.0 + assert state["reduction_type"] == "min" + + def test_std_accumulator(self): + """Test StdAccumulator operations.""" + acc = StdAccumulator(Reduce.STD) + + # Test with zero/one values + assert acc.get_value() == 0.0 + acc.append(5.0) + assert acc.get_value() == 0.0 # std of single value is 0 + + # Test with multiple values + acc.append(7.0) # values: 5, 7, mean=6, std=1 + assert abs(acc.get_value() - 1.0) < 0.001 + + state = acc.get_state() + assert state["sum"] == 12.0 + assert state["sum_sq"] == 74.0 # 5^2 + 7^2 = 25 + 49 = 74 + assert state["count"] == 2 + + @pytest.mark.parametrize( + "accumulator_class,states,expected", + [ + ( + MeanAccumulator, + [ + {"reduction_type": "mean", "sum": 10.0, "count": 2}, + {"reduction_type": "mean", "sum": 20.0, "count": 3}, + ], + 6.0, # (10+20) / (2+3) + ), + ( + SumAccumulator, + [ + {"reduction_type": "sum", "total": 10.0}, + {"reduction_type": "sum", "total": 15.0}, + ], + 25.0, + ), + ], + ) + def 
test_accumulator_state_reduction(self, accumulator_class, states, expected): + """Test cross-accumulator state reduction.""" + result = accumulator_class.get_reduced_value_from_states(states) + assert result == expected + + def test_reduce_enum_accumulator_mapping(self): + """Test that Reduce enum correctly maps to accumulator classes.""" + assert Reduce.MEAN.accumulator_class == MeanAccumulator + assert Reduce.SUM.accumulator_class == SumAccumulator + assert Reduce.MAX.accumulator_class == MaxAccumulator + assert Reduce.MIN.accumulator_class == MinAccumulator + assert Reduce.STD.accumulator_class == StdAccumulator + + class TestCriticalFixes: """Test critical production fixes from Diff 1.""" @@ -41,7 +241,7 @@ async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog): collector = MetricCollector() # Should not raise error, just log warning and return empty dict - result = await collector.flush(step=1, return_state=True) + result = await collector.flush(global_step=1, return_state=True) assert result == {} assert any( "Cannot flush collected metrics" in record.message @@ -98,7 +298,9 @@ async def test_console_backend(self, mock_actor_name): await backend.init(role="local") # Test log - should not raise - await backend.log({"test": 1.0}, step=1) + # Create a test metric + test_metric = Metric("test", 1.0, Reduce.MEAN) + await backend.log([test_metric], global_step=1) await backend.finish() # Should not raise diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py index 6af7331f1..01d1603d1 100644 --- a/tests/unit_tests/observability/test_perf_tracker.py +++ b/tests/unit_tests/observability/test_perf_tracker.py @@ -12,7 +12,7 @@ import pytest import torch -from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import Reduce from forge.observability.perf_tracker 
import _TimerCPU, _TimerCUDA, trace, Tracer @@ -135,7 +135,7 @@ def test_comprehensive_workflow( if timer == "gpu" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - monkeypatch.setenv(METRIC_TIMER_USES_CUDA, str(timer == "gpu")) + monkeypatch.setenv(METRIC_TIMER_USES_GPU, str(timer == "gpu")) async def run_concurrent_tasks(): start_time = time.perf_counter() @@ -370,17 +370,17 @@ async def disabled_workflow(): ("false", _TimerCPU), ], ) - def test_metric_timer_uses_cuda_override( + def test_metric_timer_uses_gpu_override( self, env_value, expected_backend, monkeypatch ): - """Test METRIC_TIMER_USES_CUDA env var overrides timer parameter.""" + """Test METRIC_TIMER_USES_GPU env var overrides timer parameter.""" if env_value == "true" and not torch.cuda.is_available(): pytest.skip("CUDA not available") with patch("torch.cuda.is_available", return_value=True), patch( "forge.observability.perf_tracker.record_metric" ): - monkeypatch.setenv(METRIC_TIMER_USES_CUDA, env_value) + monkeypatch.setenv(METRIC_TIMER_USES_GPU, env_value) # Test with timer="cpu" (should be overridden by env) tracer = Tracer("env_test", timer="cpu") From 41ceaa48d076a121719807d80bfbd524ac771d8f Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:06:48 -0700 Subject: [PATCH 3/7] update backend role typehints and enum --- src/forge/observability/__init__.py | 2 - src/forge/observability/metric_actors.py | 20 ++--- src/forge/observability/metrics.py | 79 ++++++------------- .../unit_tests/observability/test_metrics.py | 42 +++++++--- 4 files changed, 68 insertions(+), 75 deletions(-) diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index f37dacebd..b970e57fa 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -15,7 +15,6 @@ get_actor_name_with_rank, get_logger_backend_class, LoggerBackend, - LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -46,7 +45,6 @@ "BackendRole", # Enums 
"Reduce", - "LoggingMode", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 85dd56150..0c4d15c34 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -13,6 +13,7 @@ from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( + BackendRole, get_logger_backend_class, LoggerBackend, MetricCollector, @@ -147,14 +148,13 @@ async def init_backends( self, metadata_per_primary_backend: Dict[str, Dict[str, Any]], config: Dict[str, Any], - ): + ) -> None: """Init local (per-rank) logger backends and MetricCollector.""" collector = MetricCollector() await collector.init_backends(metadata_per_primary_backend, config) @endpoint - async def shutdown(self): - + async def shutdown(self) -> None: collector = MetricCollector() await collector.shutdown() @@ -185,7 +185,7 @@ def __init__(self): self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} @endpoint - async def init_backends(self, config: Dict[str, Any]): + async def init_backends(self, config: Dict[str, Any]) -> None: """ Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors in all registered fetchers. 
@@ -208,7 +208,7 @@ async def init_backends(self, config: Dict[str, Any]): for backend_name, backend_config in config.items(): backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role="global") + await backend.init(role=BackendRole.GLOBAL) # Extract metadata from primary logger to be shared with secondary loggers # and store it @@ -236,7 +236,9 @@ async def init_backends(self, config: Dict[str, Any]): await asyncio.gather(*tasks, return_exceptions=True) @endpoint - async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh): + async def register_fetcher( + self, fetcher: LocalFetcherActor, name: str | ProcMesh + ) -> None: """Registers a fetcher with the global actor. Each key represents a process mesh. If there are 2 processes, each with 2 replicas with N gpus, we would have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" @@ -250,7 +252,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMes ) @endpoint - async def deregister_fetcher(self, name: str | ProcMesh): + async def deregister_fetcher(self, name: str | ProcMesh) -> None: if name not in self.fetchers: logger.warning( f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." @@ -260,7 +262,7 @@ async def deregister_fetcher(self, name: str | ProcMesh): del self.fetchers[name] @endpoint - async def flush(self, global_step: int): + async def flush(self, global_step: int) -> None: """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. 
@@ -339,7 +341,7 @@ def get_fetcher_count(self) -> int: return len(self.fetchers) @endpoint - async def shutdown(self): + async def shutdown(self) -> None: # Finish per-rank logger_backends via fetchers if self.fetchers: tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index a32e6d7a1..5ba396d7f 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -18,41 +18,15 @@ logger = logging.getLogger(__name__) -class BackendRole: +class BackendRole(Enum): """Backend role constants for metric logging actors. Defines whether an actor operates as a local (per-rank) or global (controller) role in the distributed metrics collection system. """ - LOCAL: str = "local" - GLOBAL: str = "global" - - -class LoggingMode(Enum): - """Metric logging behavior for distributed training scenarios. - - Each mode serves different observability needs: - - GLOBAL_REDUCE = "global_reduce" - Best for: Metrics that are best visualized as a single value per step. - Behavior: All ranks accumulate → controller reduces → single log entry - Example use: 8 ranks training, want 1 loss value per step averaged across all - - PER_RANK_REDUCE = "per_rank_reduce" - Best for: Per-rank performance metrics, debugging individual rank behavior - Behavior: Each rank accumulates + logs its own reduced values - Example use: Monitor GPU utilization per rank, get 8 separate log entries per step - - PER_RANK_NO_REDUCE = "per_rank_no_reduce" - Best for: Real-time streaming, time-series debugging - Behavior: Raw values logged immediately on record_metric() calls - Example use: See what every rank is doing in real time. 
- """ - - GLOBAL_REDUCE = "global_reduce" - PER_RANK_REDUCE = "per_rank_reduce" - PER_RANK_NO_REDUCE = "per_rank_no_reduce" + LOCAL = "local" + GLOBAL = "global" class Reduce(Enum): @@ -235,7 +209,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metri class MetricAccumulator(ABC): """Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them.""" - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: self.reduction_type = reduction @abstractmethod @@ -266,7 +240,7 @@ def reset(self) -> None: class MeanAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.count = 0 @@ -298,7 +272,7 @@ def reset(self) -> None: class SumAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.total = 0.0 @@ -321,7 +295,7 @@ def reset(self) -> None: class MaxAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.max_val = float("-inf") @@ -344,7 +318,7 @@ def reset(self) -> None: class MinAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.min_val = float("inf") @@ -367,7 +341,7 @@ def reset(self) -> None: class StdAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.sum_sq = 0.0 @@ -453,7 +427,7 @@ def __new__(cls): ) return inst - def __init__(self): + def __init__(self) -> None: if hasattr(self, "_is_initialized"): return @@ -493,7 +467,7 @@ async def init_backends( # instantiate local backend logger_backend = 
get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role="local", primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata ) self.logger_backends.append(logger_backend) @@ -592,20 +566,20 @@ async def shutdown(self): class LoggerBackend(ABC): """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: self.logger_backend_config = logger_backend_config @abstractmethod async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). Args: - role (str): "global" (controller/primary) or "local" (per-rank/secondary). + role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. @@ -630,18 +604,18 @@ def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: class ConsoleBackend(LoggerBackend): """Simple console logging of metrics.""" - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: super().__init__(logger_backend_config) async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: self.prefix = ( get_actor_name_with_rank() if self.logger_backend_config.get("reduce_across_ranks", True) - else "GLOBAL" + else "Controller" ) async def log(self, metrics: List[Metric], global_step: int) -> None: @@ -675,7 +649,7 @@ class WandbBackend(LoggerBackend): group (str, optional): WandB group name for organizing runs. 
Defaults to "experiment_group" """ - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: super().__init__(logger_backend_config) self.project = logger_backend_config["project"] self.group = logger_backend_config.get("group", "experiment_group") @@ -688,25 +662,22 @@ def __init__(self, logger_backend_config: Dict[str, Any]): async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - if role not in ["global", "local"]: - raise ValueError( - f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." - ) - self.name = ( - get_actor_name_with_rank() if role == "local" else "global_controller" + get_actor_name_with_rank() + if role == BackendRole.LOCAL + else "global_controller" ) # Default global mode: only inits on controller if self.reduce_across_ranks: - if role != "global": + if role != BackendRole.GLOBAL: logger.debug( f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
) @@ -714,10 +685,10 @@ async def init( await self._init_global() # Per-rank modes based on share_run_id bool - elif role == "global" and self.share_run_id: + elif role == BackendRole.GLOBAL and self.share_run_id: await self._init_shared_global() - elif role == "local": + elif role == BackendRole.LOCAL: if self.share_run_id: await self._init_shared_local(primary_logger_metadata) else: diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index e65eb4f42..ee635c582 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -16,7 +16,6 @@ BackendRole, ConsoleBackend, get_logger_backend_class, - LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -66,15 +65,38 @@ def test_record_metric(self, mock_rank): mock_collector.push.assert_called_once_with("loss", 1.5, Reduce.MEAN) def test_new_enums_and_constants(self): - """Test new LoggingMode enum and BackendRole constants.""" - # Test LoggingMode enum values - assert LoggingMode.GLOBAL_REDUCE.value == "global_reduce" - assert LoggingMode.PER_RANK_REDUCE.value == "per_rank_reduce" - assert LoggingMode.PER_RANK_NO_REDUCE.value == "per_rank_no_reduce" - - # Test BackendRole constants - assert BackendRole.LOCAL == "local" - assert BackendRole.GLOBAL == "global" + """Test BackendRole constants and usage.""" + # Test BackendRole enum values + assert BackendRole.LOCAL.value == "local" + assert BackendRole.GLOBAL.value == "global" + + # Test that BackendRole is a proper Enum + assert isinstance(BackendRole.LOCAL, BackendRole) + assert isinstance(BackendRole.GLOBAL, BackendRole) + + @patch("forge.observability.metrics.get_actor_name_with_rank") + @pytest.mark.asyncio + async def test_backend_role_usage(self, mock_actor_name): + """Test that BackendRole constants are actually used instead of string literals.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + # Test ConsoleBackend + console_backend = 
ConsoleBackend({}) + await console_backend.init(role=BackendRole.LOCAL) + + # Test WandbBackend role validation without WandB initialization + wandb_backend = WandbBackend({"project": "test"}) + + # Mock all the WandB init methods to focus only on role validation + with patch.object(wandb_backend, "_init_global"), patch.object( + wandb_backend, "_init_shared_global" + ), patch.object(wandb_backend, "_init_shared_local"), patch.object( + wandb_backend, "_init_per_rank" + ): + + # Should not raise error for valid roles (type system prevents invalid values) + await wandb_backend.init(role=BackendRole.GLOBAL) + await wandb_backend.init(role=BackendRole.LOCAL) class TestReduceOperations: From 8a24e715c15bb3d18337a75fc491a8a81e605291 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:17:08 -0700 Subject: [PATCH 4/7] update where we check FORGE_DISABLE_METRICS --- src/forge/controller/provisioner.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 429c5760f..cf712079b 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -20,6 +20,7 @@ from forge.controller.launcher import BaseLauncher, get_launcher +from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import ProcessConfig, ProvisionerConfig @@ -263,8 +264,8 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh # Spawn local fetcher actor on each process and register with global logger - # Can be disabled by FORGE_DISABLE_METRICS env var - _ = await get_or_create_metric_logger(procs) + if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": + _ = await get_or_create_metric_logger(procs) return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): From 3f3bc51bd69316cd403c874a72b7e9824ae9f190 Mon Sep 17 00:00:00 2001 From: Felipe 
Mello Date: Wed, 8 Oct 2025 10:17:48 -0700 Subject: [PATCH 5/7] remove protected import --- src/forge/observability/metrics.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 4d527ec1b..64843f110 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,6 +13,8 @@ from monarch.actor import context, current_rank +from forge.util.logging import log_once + logger = logging.getLogger(__name__) @@ -437,8 +439,6 @@ async def init_backends( def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: if not self._is_initialized: - from forge.util.logging import log_once - log_once( logger, level=logging.WARNING, @@ -472,8 +472,6 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - from forge.util.logging import log_once - log_once( logger, level=logging.WARNING, From 4fe26116d9562826f0fcc4cc37bbce48c40ccb18 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 10:23:18 -0700 Subject: [PATCH 6/7] protect import --- src/forge/controller/provisioner.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index cf712079b..5ca331f32 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -21,7 +21,6 @@ from forge.controller.launcher import BaseLauncher, get_launcher from forge.env_constants import FORGE_DISABLE_METRICS -from forge.observability.metric_actors import get_or_create_metric_logger from forge.types import ProcessConfig, ProvisionerConfig @@ -265,6 +264,8 @@ def bootstrap(env: dict[str, str]): # Spawn local fetcher actor on each process and register with global logger if os.getenv(FORGE_DISABLE_METRICS, "false").lower() != "true": + from forge.observability.metric_actors import 
get_or_create_metric_logger + _ = await get_or_create_metric_logger(procs) return procs @@ -286,6 +287,10 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): async with self._lock: # Deregister local logger from global logger if hasattr(proc_mesh, "_local_fetcher"): + from forge.observability.metric_actors import ( + get_or_create_metric_logger, + ) + global_logger = await get_or_create_metric_logger(proc_mesh) await global_logger.deregister_fetcher.call_one(proc_mesh) From d81a4edd0b4677a03d5c2484bdc2923fe2fef0e1 Mon Sep 17 00:00:00 2001 From: Felipe Mello Date: Wed, 8 Oct 2025 11:22:37 -0700 Subject: [PATCH 7/7] record_metric uses dataclass Metric --- src/forge/observability/metrics.py | 47 +++++++++++++++---- .../unit_tests/observability/test_metrics.py | 17 +++++-- 2 files changed, 49 insertions(+), 15 deletions(-) diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index f08d8d637..3c5386af9 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -118,8 +118,7 @@ def get_actor_name_with_rank() -> str: def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: - """ - Records a metric value for later reduction and logging. + """Thin wrapper to send metrics to per-rank local MetricCollectors. Relies on a per-rank MetricCollector singleton for ease of use, i.e. call `record_metric` anywhere in the code without moving the @@ -134,12 +133,14 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None Can be disabled globally by setting the environment variable `FORGE_DISABLE_METRICS=true`. 
""" - # Skip metrics collection if disabled for tests + # Skip metrics collection if os.getenv("FORGE_DISABLE_METRICS", "false").lower() == "true": return + # timestamp is added automatically by the Metric class + metric = Metric(key=key, value=value, reduction=reduction) collector = MetricCollector() - collector.push(key, value, reduction) + collector.push(metric) def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: @@ -475,7 +476,20 @@ async def init_backends( self._is_initialized = True - def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: + def push(self, metric: Metric) -> None: + """Process a metric according to configured logging modes. + + Args: + metric: Metric dataclass containing key, value, reduction type, and timestamp. + + Raises: + TypeError: If metric is not a Metric object. + + Example: + collector = MetricCollector() + metric = Metric("loss", 0.5, Reduce.MEAN) + collector.push(metric) + """ if not self._is_initialized: log_once( logger, @@ -491,10 +505,17 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: ) return - if key not in self.accumulators: - self.accumulators[key] = reduction.accumulator_class(reduction) + # Validate metric object + if not isinstance(metric, Metric): + raise TypeError(f"Expected {Metric} object, got {type(metric)}") - self.accumulators[key].append(value) + # Always accumulate for reduction and state return + key = metric.key + if key not in self.accumulators: + self.accumulators[key] = metric.reduction.accumulator_class( + metric.reduction + ) + self.accumulators[key].append(metric.value) async def flush( self, global_step: int, return_state: bool = False @@ -584,11 +605,17 @@ async def init( Raises: ValueError if missing metadata for shared local init. 
""" - if primary_logger_metadata is None: - primary_logger_metadata = {} pass + @abstractmethod async def log(self, metrics: List[Metric], global_step: int) -> None: + """ + Log a batch of metrics to the backend. + + Args: + metrics: List of Metric objects to log. + global_step: Step number for x-axis alignment across metrics. + """ pass async def finish(self) -> None: diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index ee635c582..563f52e6c 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -52,7 +52,7 @@ def test_metric_creation_custom_timestamp(self, mock_rank): assert metric.timestamp == custom_time def test_record_metric(self, mock_rank): - """Test record_metric calls collector correctly.""" + """Test record_metric creates correct Metric and calls collector.""" # Mock the MetricCollector constructor to return a mock instance mock_collector = MagicMock() @@ -61,8 +61,14 @@ def test_record_metric(self, mock_rank): ): record_metric("loss", 1.5, Reduce.MEAN) - # Verify push was called on the mock collector with correct parameters - mock_collector.push.assert_called_once_with("loss", 1.5, Reduce.MEAN) + # Verify push was called on the mock collector + mock_collector.push.assert_called_once() + + # Verify the metric passed to push + pushed_metric = mock_collector.push.call_args[0][0] + assert pushed_metric.key == "loss" + assert pushed_metric.value == 1.5 + assert pushed_metric.reduction == Reduce.MEAN def test_new_enums_and_constants(self): """Test BackendRole constants and usage.""" @@ -250,9 +256,10 @@ class TestCriticalFixes: def test_uninitialized_push_logs_warning(self, mock_rank, caplog): """Test MetricCollector.push() logs warning when uninitialized.""" collector = MetricCollector() + metric = Metric("test", 1.0, Reduce.MEAN) # Should not raise error, just log warning and return - collector.push("test", 1.0, Reduce.MEAN) + 
collector.push(metric) assert any( "Metric logging backends" in record.message for record in caplog.records ) @@ -317,7 +324,7 @@ async def test_console_backend(self, mock_actor_name): backend = ConsoleBackend({}) - await backend.init(role="local") + await backend.init(role=BackendRole.LOCAL) # Test log - should not raise # Create a test metric