diff --git a/src/forge/env_constants.py b/src/forge/env_constants.py index a4e024d83..6e0fc30e7 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -11,7 +11,7 @@ # Force all timing methods in forge.observability.perf_tracker.py to use # CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function. -METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" +METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU" # Makes forge.observability.metrics.record_metric a no-op # and disables spawning LocalFetcherActor in get_or_create_metric_logger diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 52262eed5..b970e57fa 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -10,15 +10,14 @@ LocalFetcherActor, ) from .metrics import ( + BackendRole, ConsoleBackend, - # Utility functions get_actor_name_with_rank, get_logger_backend_class, - # Backend classes LoggerBackend, MaxAccumulator, MeanAccumulator, - # Accumulator classes + Metric, MetricAccumulator, MetricCollector, MinAccumulator, @@ -41,6 +40,9 @@ # Performance tracking "Tracer", "trace", + # Data classes + "Metric", + "BackendRole", # Enums "Reduce", # Actor classes diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index edd1f24d8..0c4d15c34 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -13,6 +13,7 @@ from forge.env_constants import FORGE_DISABLE_METRICS from forge.observability.metrics import ( + BackendRole, get_logger_backend_class, LoggerBackend, MetricCollector, @@ -106,7 +107,7 @@ async def get_or_create_metric_logger( "local_fetcher_actor", LocalFetcherActor, global_logger ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) - proc._local_fetcher = local_fetcher_actor + proc._local_fetcher = local_fetcher_actor # pyre-ignore return global_logger @@ -125,13 +126,13 @@ def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None @endpoint async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. This should only ever be called by the global logger. Args: - step (int): train step used by backends to align all metrics on the same x-axis + global_step (int): step used by backends to align all metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: @@ -139,7 +140,7 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. 
""" collector = MetricCollector() - result = await collector.flush(step, return_state=return_state) + result = await collector.flush(global_step, return_state=return_state) return result @endpoint @@ -147,14 +148,13 @@ async def init_backends( self, metadata_per_primary_backend: Dict[str, Dict[str, Any]], config: Dict[str, Any], - ): + ) -> None: """Init local (per-rank) logger backends and MetricCollector.""" collector = MetricCollector() await collector.init_backends(metadata_per_primary_backend, config) @endpoint - async def shutdown(self): - + async def shutdown(self) -> None: collector = MetricCollector() await collector.shutdown() @@ -185,7 +185,7 @@ def __init__(self): self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} @endpoint - async def init_backends(self, config: Dict[str, Any]): + async def init_backends(self, config: Dict[str, Any]) -> None: """ Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors in all registered fetchers. @@ -208,7 +208,7 @@ async def init_backends(self, config: Dict[str, Any]): for backend_name, backend_config in config.items(): backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role="global") + await backend.init(role=BackendRole.GLOBAL) # Extract metadata from primary logger to be shared with secondary loggers # and store it @@ -236,7 +236,9 @@ async def init_backends(self, config: Dict[str, Any]): await asyncio.gather(*tasks, return_exceptions=True) @endpoint - async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh): + async def register_fetcher( + self, fetcher: LocalFetcherActor, name: str | ProcMesh + ) -> None: """Registers a fetcher with the global actor. Each key represents a process mesh. If there are 2 processes, each with 2 replicas with N gpus, we would have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" @@ -250,7 +252,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMes ) @endpoint - async def deregister_fetcher(self, name: str | ProcMesh): + async def deregister_fetcher(self, name: str | ProcMesh) -> None: if name not in self.fetchers: logger.warning( f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." @@ -260,13 +262,13 @@ async def deregister_fetcher(self, name: str | ProcMesh): del self.fetchers[name] @endpoint - async def flush(self, step: int): + async def flush(self, global_step: int) -> None: """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. Args: - step (int): Global step for logging. + global_step (int): step for logging. 
""" if not self.fetchers: return @@ -285,12 +287,14 @@ async def flush(self, step: int): for backend_config in config.values() ) - logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + logger.debug( + f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers" + ) # Broadcast flush to all fetchers results = await asyncio.gather( *[ - f.flush.call(step, return_state=requires_reduce) + f.flush.call(global_step, return_state=requires_reduce) for f in self.fetchers.values() ], return_exceptions=True, @@ -314,10 +318,10 @@ async def flush(self, step: int): ) if not all_local_states: - logger.warning(f"No states to reduce for step {step}") + logger.warning(f"No states to reduce for global_step {global_step}") return - # Reduce + # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) # Log to each global logger_backend @@ -325,7 +329,7 @@ async def flush(self, step: int): logger_backend_name, logger_backend, ) in self.global_logger_backends.items(): - await logger_backend.log(reduced_metrics, step) + await logger_backend.log(reduced_metrics, global_step) @endpoint def has_fetcher(self, name: str | ProcMesh) -> bool: @@ -337,7 +341,7 @@ def get_fetcher_count(self) -> int: return len(self.fetchers) @endpoint - async def shutdown(self): + async def shutdown(self) -> None: # Finish per-rank logger_backends via fetchers if self.fetchers: tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 64843f110..3c5386af9 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -5,12 +5,14 @@ # LICENSE file in the root directory of this source tree. import logging - import os from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional +import pytz from monarch.actor import context, current_rank from forge.util.logging import log_once @@ -18,6 +20,17 @@ logger = logging.getLogger(__name__) +class BackendRole(Enum): + """Backend role constants for metric logging actors. + + Defines whether an actor operates as a local (per-rank) or global (controller) role + in the distributed metrics collection system. + """ + + LOCAL = "local" + GLOBAL = "global" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -37,6 +50,24 @@ def accumulator_class(self): return mapping[self] +@dataclass +class Metric: + """Container for metric data including key, value, reduction type, and timestamp. + + Timestamp is automatically set to current EST time if not provided. + """ + + key: str + value: Any + reduction: Reduce + timestamp: Optional[float] = None + + def __post_init__(self): + if self.timestamp is None: + # Always record in UTC timezone + self.timestamp = datetime.now(pytz.UTC).timestamp() + + def get_actor_name_with_rank() -> str: """ Extracts actor information from Monarch context to form a logging name. @@ -87,8 +118,7 @@ def get_actor_name_with_rank() -> str: def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: - """ - Records a metric value for later reduction and logging. + """Thin wrapper to send metrics to per-rank local MetricCollectors. Relies on a per-rank MetricCollector singleton for ease of use, i.e. 
call `record_metric` anywhere in the code without moving the @@ -103,16 +133,18 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None Can be disabled globally by setting the environment variable `FORGE_DISABLE_METRICS=true`. """ - # Skip metrics collection if disabled for tests + # Skip metrics collection if os.getenv("FORGE_DISABLE_METRICS", "false").lower() == "true": return + # timestamp is added automatically by the Metric class + metric = Metric(key=key, value=value, reduction=reduction) collector = MetricCollector() - collector.push(key, value, reduction) + collector.push(metric) -def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: - """Reduce metric accumulators states to a single value per metric. +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: + """Reduce metric accumulators states to a list of metrics. Can be used when reducing metrics across ranks or services, as merging states is more precise than merging locally reduced metrics. @@ -122,7 +154,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. Returns: - Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + List[Metric]: List of reduced metrics Example: states = [ @@ -130,18 +162,18 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, ] reduce_metrics_states(states) - >>> {"loss": 2.0} + >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] Raises: ValueError: on mismatched reduction types for the same metric key. """ if not states: - return {} + return [] # Collect unique keys across all all_keys = set(k for state in states for k in state) - reduced_metrics = {} + reduced_metrics = [] for key in all_keys: metric_states = [state.get(key) for state in states if key in state] if not metric_states: @@ -160,7 +192,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, metric_accumulator = Reduce(first_reduction_type).accumulator_class reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) - reduced_metrics[key] = reduced_value + + # Create Metric object with reduced value + metric = Metric( + key=key, + value=reduced_value, + reduction=Reduce(first_reduction_type), + ) + reduced_metrics.append(metric) return reduced_metrics @@ -173,7 +212,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, class MetricAccumulator(ABC): """Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them.""" - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: self.reduction_type = reduction @abstractmethod @@ -204,7 +243,7 @@ def reset(self) -> None: class MeanAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.count = 0 @@ -236,7 +275,7 @@ def reset(self) -> None: class SumAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.total = 0.0 @@ -259,7 +298,7 @@ def reset(self) -> None: class MaxAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) 
self.max_val = float("-inf") @@ -282,7 +321,7 @@ def reset(self) -> None: class MinAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.min_val = float("inf") @@ -305,7 +344,7 @@ def reset(self) -> None: class StdAccumulator(MetricAccumulator): - def __init__(self, reduction: Reduce): + def __init__(self, reduction: Reduce) -> None: super().__init__(reduction) self.sum = 0.0 self.sum_sq = 0.0 @@ -391,7 +430,7 @@ def __new__(cls): ) return inst - def __init__(self): + def __init__(self) -> None: if hasattr(self, "_is_initialized"): return @@ -431,13 +470,26 @@ async def init_backends( # instantiate local backend logger_backend = get_logger_backend_class(backend_name)(backend_config) await logger_backend.init( - role="local", primary_logger_metadata=primary_metadata + role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata ) self.logger_backends.append(logger_backend) self._is_initialized = True - def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: + def push(self, metric: Metric) -> None: + """Process a metric according to configured logging modes. + + Args: + metric: Metric dataclass containing key, value, reduction type, and timestamp. + + Raises: + TypeError: If metric is not a Metric object. + + Example: + collector = MetricCollector() + metric = Metric("loss", 0.5, Reduce.MEAN) + collector.push(metric) + """ if not self._is_initialized: log_once( logger, @@ -453,18 +505,25 @@ def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: ) return - if key not in self.accumulators: - self.accumulators[key] = reduction.accumulator_class(reduction) + # Validate metric object + if not isinstance(metric, Metric): + raise TypeError(f"Expected {Metric} object, got {type(metric)}") - self.accumulators[key].append(value) + # Always accumulate for reduction and state return + key = metric.key + if key not in self.accumulators: + self.accumulators[key] = metric.reduction.accumulator_class( + metric.reduction + ) + self.accumulators[key].append(metric.value) async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. Args: - step (int): Step used by backends to align metrics on the same x-axis + global_step (int): step used by backends to align metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. 
Returns: @@ -485,7 +544,7 @@ async def flush( if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for global_step {global_step}" ) return {} @@ -497,14 +556,12 @@ async def flush( # Reduce metrics from states for logging if any per-rank backend if self.logger_backends: - metrics = {} - for key, state in states.items(): - acc_class = Reduce(state["reduction_type"]).accumulator_class - metrics[key] = acc_class.get_reduced_value_from_states([state]) + # Use reduce_metrics_states for consistency + reduced_metrics = reduce_metrics_states([states]) # Log to local logger_backends for logger_backend in self.logger_backends: - await logger_backend.log(metrics, step) + await logger_backend.log(reduced_metrics, global_step) return states if return_state else {} @@ -528,31 +585,37 @@ async def shutdown(self): class LoggerBackend(ABC): """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: self.logger_backend_config = logger_backend_config @abstractmethod async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). Args: - role (str): "global" (controller/primary) or "local" (per-rank/secondary). + role (BackendRole): BackendRole.GLOBAL (controller/primary) or BackendRole.LOCAL (per-rank/secondary). Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. Raises: ValueError if missing metadata for shared local init. """ - if primary_logger_metadata is None: - primary_logger_metadata = {} pass - async def log(self, metrics: Dict[str, Any], step: int) -> None: + @abstractmethod + async def log(self, metrics: List[Metric], global_step: int) -> None: + """ + Log a batch of metrics to the backend. + + Args: + metrics: List of Metric objects to log. + global_step: Step number for x-axis alignment across metrics. 
+ """ pass async def finish(self) -> None: @@ -566,25 +629,28 @@ def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: class ConsoleBackend(LoggerBackend): """Simple console logging of metrics.""" - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: super().__init__(logger_backend_config) async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: self.prefix = ( get_actor_name_with_rank() if self.logger_backend_config.get("reduce_across_ranks", True) - else "GLOBAL" + else "Controller" ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: - logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") - for key, value in sorted(metrics.items()): - logger.info(f" {key}: {value}") - logger.info("==============================\n") + async def log(self, metrics: List[Metric], global_step: int) -> None: + metrics_str = "\n".join( + f" {metric.key}: {metric.value}" + for metric in sorted(metrics, key=lambda m: m.key) + ) + logger.info( + f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" + ) async def finish(self) -> None: pass @@ -608,7 +674,7 @@ class WandbBackend(LoggerBackend): group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" """ - def __init__(self, logger_backend_config: Dict[str, Any]): + def __init__(self, logger_backend_config: Dict[str, Any]) -> None: super().__init__(logger_backend_config) self.project = logger_backend_config["project"] self.group = logger_backend_config.get("group", "experiment_group") @@ -621,25 +687,22 @@ def __init__(self, logger_backend_config: Dict[str, Any]): async def init( self, - role: str, + role: BackendRole, primary_logger_metadata: Optional[Dict[str, Any]] = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - if role not in ["global", "local"]: - raise ValueError( - f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." - ) - self.name = ( - get_actor_name_with_rank() if role == "local" else "global_controller" + get_actor_name_with_rank() + if role == BackendRole.LOCAL + else "global_controller" ) # Default global mode: only inits on controller if self.reduce_across_ranks: - if role != "global": + if role != BackendRole.GLOBAL: logger.debug( f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
) @@ -647,10 +710,10 @@ async def init( await self._init_global() # Per-rank modes based on share_run_id bool - elif role == "global" and self.share_run_id: + elif role == BackendRole.GLOBAL and self.share_run_id: await self._init_shared_global() - elif role == "local": + elif role == BackendRole.LOCAL: if self.share_run_id: await self._init_shared_local(primary_logger_metadata) else: @@ -699,11 +762,17 @@ async def _init_shared_local(self, primary_metadata: Dict[str, Any]): settings=settings, ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: + async def log(self, metrics: List[Metric], global_step: int) -> None: if self.run: - log_data = {**metrics, "global_step": step} + # Convert metrics to WandB log format + log_data = {"global_step": global_step} + for metric in metrics: + log_data[metric.key] = metric.value + self.run.log(log_data) - logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + logger.info( + f"WandbBackend: Logged {len(metrics)} metrics at global_step {global_step}" + ) else: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index e85b81e26..47577d916 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -15,7 +15,7 @@ import torch -from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import record_metric, Reduce # Thread-local memory tracking state @@ -125,7 +125,7 @@ def start(self) -> None: # Start timing (always enabled) time_with_gpu_events = ( - os.getenv(METRIC_TIMER_USES_CUDA, str(self.time_with_gpu)).lower() == "true" + os.getenv(METRIC_TIMER_USES_GPU, str(self.time_with_gpu)).lower() == "true" ) and torch.cuda.is_available() self._timer = _TimerCUDA() if time_with_gpu_events else _TimerCPU() self._timer.start() diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 3e864bdf7..563f52e6c 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -4,33 +4,262 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-"""Unit tests for core metrics functionality focusing on critical fixes in Diff 1.""" +"""Unit tests for core metrics functionality.""" +import time from unittest.mock import MagicMock, patch import pytest from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import ( + BackendRole, ConsoleBackend, get_logger_backend_class, + MaxAccumulator, MeanAccumulator, + Metric, MetricCollector, + MinAccumulator, record_metric, Reduce, + reduce_metrics_states, + StdAccumulator, + SumAccumulator, WandbBackend, ) +class TestMetricCreation: + """Test Metric object creation and record_metric function - Diff 2 features.""" + + def test_metric_creation_automatic_timestamp(self, mock_rank): + """Test Metric object creation with automatic timestamp.""" + before_time = time.time() + metric = Metric("test_key", 42.0, Reduce.MEAN) + after_time = time.time() + + assert metric.key == "test_key" + assert metric.value == 42.0 + assert metric.reduction == Reduce.MEAN + assert metric.timestamp is not None + assert before_time <= metric.timestamp <= after_time + + def test_metric_creation_custom_timestamp(self, mock_rank): + """Test Metric object creation with custom timestamp.""" + custom_time = 1234567890.0 + metric = Metric("test_key2", 24.0, Reduce.SUM, timestamp=custom_time) + assert metric.timestamp == custom_time + + def test_record_metric(self, mock_rank): + """Test record_metric creates correct Metric and calls collector.""" + # Mock the MetricCollector constructor to return a mock instance + mock_collector = MagicMock() + + with patch( + "forge.observability.metrics.MetricCollector", return_value=mock_collector + ): + record_metric("loss", 1.5, Reduce.MEAN) + + # Verify push was called on the mock collector + mock_collector.push.assert_called_once() + + # Verify the metric passed to push + pushed_metric = mock_collector.push.call_args[0][0] + assert pushed_metric.key == "loss" + assert pushed_metric.value == 1.5 + assert pushed_metric.reduction == Reduce.MEAN + + def test_new_enums_and_constants(self): + """Test BackendRole constants and usage.""" + # Test BackendRole enum values + assert BackendRole.LOCAL.value == "local" + assert BackendRole.GLOBAL.value == "global" + + # Test that BackendRole is a proper Enum + assert isinstance(BackendRole.LOCAL, BackendRole) + assert isinstance(BackendRole.GLOBAL, BackendRole) + + @patch("forge.observability.metrics.get_actor_name_with_rank") + @pytest.mark.asyncio + async def test_backend_role_usage(self, mock_actor_name): + """Test that BackendRole constants are actually used instead of string literals.""" + mock_actor_name.return_value = "TestActor_abcd_r0" + + # Test ConsoleBackend + console_backend = ConsoleBackend({}) + await console_backend.init(role=BackendRole.LOCAL) + + # Test WandbBackend role validation without WandB initialization + wandb_backend = WandbBackend({"project": "test"}) + + # Mock all the WandB init methods to focus only on role validation + with patch.object(wandb_backend, "_init_global"), patch.object( + wandb_backend, "_init_shared_global" + ), patch.object(wandb_backend, "_init_shared_local"), patch.object( + wandb_backend, "_init_per_rank" + ): + + # Should not raise error for valid roles (type system prevents invalid values) + await wandb_backend.init(role=BackendRole.GLOBAL) + await wandb_backend.init(role=BackendRole.LOCAL) + + +class TestReduceOperations: + """Test reduce_metrics_states function returning List[Metric] - Diff 2 feature.""" + + def test_empty_states(self): + """Test 
reduce_metrics_states with empty input.""" + result = reduce_metrics_states([]) + assert result == [] + + def test_single_state(self): + """Test reduce_metrics_states with single state.""" + states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}] + result = reduce_metrics_states(states) + assert len(result) == 1 + assert result[0].key == "loss" + assert result[0].value == 5.0 + assert result[0].reduction == Reduce.MEAN + + def test_multiple_states(self): + """Test reduce_metrics_states with multiple states.""" + states = [ + {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, + {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}}, + {"accuracy": {"reduction_type": "sum", "total": 15.0}}, + ] + result = reduce_metrics_states(states) + + # Convert to dict for easier testing + result_dict = {metric.key: metric.value for metric in result} + assert result_dict["loss"] == 30.0 / 5.0 # 6.0 + assert result_dict["accuracy"] == 15.0 + + # Also check reduction types + for metric in result: + if metric.key == "loss": + assert metric.reduction == Reduce.MEAN + elif metric.key == "accuracy": + assert metric.reduction == Reduce.SUM + + def test_mismatched_reduction_types_raises_error(self): + """Test reduce_metrics_states raises error for mismatched reduction types.""" + states = [ + {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}, + {"loss": {"reduction_type": "sum", "total": 20.0}}, + ] + with pytest.raises(ValueError, match="Mismatched reduction types"): + reduce_metrics_states(states) + + +class TestAccumulators: + """Test all accumulator classes and their operations - Diff 2 extensions.""" + + def test_sum_accumulator(self): + """Test SumAccumulator operations.""" + acc = SumAccumulator(Reduce.SUM) + + acc.append(5.0) + acc.append(3.0) + assert acc.get_value() == 8.0 + + state = acc.get_state() + assert state["total"] == 8.0 + assert state["reduction_type"] == "sum" + + acc.reset() + assert acc.get_value() == 0.0 + + def test_max_accumulator(self): + """Test MaxAccumulator operations.""" + acc = MaxAccumulator(Reduce.MAX) + + acc.append(5.0) + acc.append(10.0) + acc.append(3.0) + assert acc.get_value() == 10.0 + + state = acc.get_state() + assert state["max_val"] == 10.0 + assert state["reduction_type"] == "max" + + def test_min_accumulator(self): + """Test MinAccumulator operations.""" + acc = MinAccumulator(Reduce.MIN) + + acc.append(5.0) + acc.append(10.0) + acc.append(3.0) + assert acc.get_value() == 3.0 + + state = acc.get_state() + assert state["min_val"] == 3.0 + assert state["reduction_type"] == "min" + + def test_std_accumulator(self): + """Test StdAccumulator operations.""" + acc = StdAccumulator(Reduce.STD) + + # Test with zero/one values + assert acc.get_value() == 0.0 + acc.append(5.0) + assert acc.get_value() == 0.0 # std of single value is 0 + + # Test with multiple values + acc.append(7.0) # values: 5, 7, mean=6, std=1 + assert abs(acc.get_value() - 1.0) < 0.001 + + state = acc.get_state() + assert state["sum"] == 12.0 + assert state["sum_sq"] == 74.0 # 5^2 + 7^2 = 25 + 49 = 74 + assert state["count"] == 2 + + @pytest.mark.parametrize( + "accumulator_class,states,expected", + [ + ( + MeanAccumulator, + [ + {"reduction_type": "mean", "sum": 10.0, "count": 2}, + {"reduction_type": "mean", "sum": 20.0, "count": 3}, + ], + 6.0, # (10+20) / (2+3) + ), + ( + SumAccumulator, + [ + {"reduction_type": "sum", "total": 10.0}, + {"reduction_type": "sum", "total": 15.0}, + ], + 25.0, + ), + ], + ) + def test_accumulator_state_reduction(self, 
accumulator_class, states, expected): + """Test cross-accumulator state reduction.""" + result = accumulator_class.get_reduced_value_from_states(states) + assert result == expected + + def test_reduce_enum_accumulator_mapping(self): + """Test that Reduce enum correctly maps to accumulator classes.""" + assert Reduce.MEAN.accumulator_class == MeanAccumulator + assert Reduce.SUM.accumulator_class == SumAccumulator + assert Reduce.MAX.accumulator_class == MaxAccumulator + assert Reduce.MIN.accumulator_class == MinAccumulator + assert Reduce.STD.accumulator_class == StdAccumulator + + class TestCriticalFixes: """Test critical production fixes from Diff 1.""" def test_uninitialized_push_logs_warning(self, mock_rank, caplog): """Test MetricCollector.push() logs warning when uninitialized.""" collector = MetricCollector() + metric = Metric("test", 1.0, Reduce.MEAN) # Should not raise error, just log warning and return - collector.push("test", 1.0, Reduce.MEAN) + collector.push(metric) assert any( "Metric logging backends" in record.message for record in caplog.records ) @@ -41,7 +270,7 @@ async def test_uninitialized_flush_logs_warning(self, mock_rank, caplog): collector = MetricCollector() # Should not raise error, just log warning and return empty dict - result = await collector.flush(step=1, return_state=True) + result = await collector.flush(global_step=1, return_state=True) assert result == {} assert any( "Cannot flush collected metrics" in record.message @@ -95,10 +324,12 @@ async def test_console_backend(self, mock_actor_name): backend = ConsoleBackend({}) - await backend.init(role="local") + await backend.init(role=BackendRole.LOCAL) # Test log - should not raise - await backend.log({"test": 1.0}, step=1) + # Create a test metric + test_metric = Metric("test", 1.0, Reduce.MEAN) + await backend.log([test_metric], global_step=1) await backend.finish() # Should not raise diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py index 6af7331f1..01d1603d1 100644 --- a/tests/unit_tests/observability/test_perf_tracker.py +++ b/tests/unit_tests/observability/test_perf_tracker.py @@ -12,7 +12,7 @@ import pytest import torch -from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import Reduce from forge.observability.perf_tracker import _TimerCPU, _TimerCUDA, trace, Tracer @@ -135,7 +135,7 @@ def test_comprehensive_workflow( if timer == "gpu" and not torch.cuda.is_available(): pytest.skip("CUDA not available") - monkeypatch.setenv(METRIC_TIMER_USES_CUDA, str(timer == "gpu")) + monkeypatch.setenv(METRIC_TIMER_USES_GPU, str(timer == "gpu")) async def run_concurrent_tasks(): start_time = time.perf_counter() @@ -370,17 +370,17 @@ async def disabled_workflow(): ("false", _TimerCPU), ], ) - def test_metric_timer_uses_cuda_override( + def test_metric_timer_uses_gpu_override( self, env_value, expected_backend, monkeypatch ): - """Test METRIC_TIMER_USES_CUDA env var overrides timer parameter.""" + """Test METRIC_TIMER_USES_GPU env var overrides timer parameter.""" if env_value == "true" and not torch.cuda.is_available(): pytest.skip("CUDA not available") with patch("torch.cuda.is_available", return_value=True), patch( "forge.observability.perf_tracker.record_metric" ): - monkeypatch.setenv(METRIC_TIMER_USES_CUDA, env_value) + monkeypatch.setenv(METRIC_TIMER_USES_GPU, env_value) # Test with timer="cpu" (should be 
overridden by env) tracer = Tracer("env_test", timer="cpu")
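
Reviewer note (not part of the patch): a minimal usage sketch of the metrics API as reshaped by this diff — the Metric dataclass, the List[Metric] return of reduce_metrics_states, and the BackendRole enum replacing the "global"/"local" strings. The values and the standalone-script framing are illustrative assumptions; record_metric and MetricCollector still expect to run inside a Forge actor context, so that path is only referenced in comments.

# Sketch under the assumptions above; the imports mirror the exports added to
# forge/observability/__init__.py in this diff.
from forge.observability.metrics import BackendRole, Metric, Reduce, reduce_metrics_states

# Building a Metric directly. record_metric("loss", 1.5, Reduce.MEAN) does this
# internally and pushes the Metric to the per-rank MetricCollector singleton.
metric = Metric(key="loss", value=1.5, reduction=Reduce.MEAN)
assert metric.timestamp is not None  # auto-filled as a UTC epoch timestamp

# Cross-rank reduction now returns a list of Metric objects instead of a {key: value} dict.
states = [
    {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}},
    {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}},
]
(reduced,) = reduce_metrics_states(states)
assert reduced.key == "loss" and reduced.value == 6.0  # (10 + 20) / (2 + 3)
assert reduced.reduction == Reduce.MEAN

# Logger backends are now initialized with the enum rather than raw strings, e.g.:
#     await backend.init(role=BackendRole.GLOBAL)  # controller / primary
#     await backend.init(role=BackendRole.LOCAL)   # per-rank / secondary
assert BackendRole.GLOBAL.value == "global" and BackendRole.LOCAL.value == "local"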
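
Likewise, a small sketch of the renamed timing override read by forge/observability/perf_tracker.py (METRIC_TIMER_USES_CUDA -> METRIC_TIMER_USES_GPU); the "true" value below is illustrative.

import os

from forge.env_constants import METRIC_TIMER_USES_GPU  # holds the env var name

# When set, this env var overrides the per-call-site timer choice in perf_tracker:
# "true" uses CUDA-event timing when torch.cuda.is_available(), "false" forces the
# CPU timer, and leaving it unset keeps each call site's own default.
os.environ[METRIC_TIMER_USES_GPU] = "true"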