diff --git a/apps/grpo/main.py b/apps/grpo/main.py index c64f00bc2..770c7b9ac 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -319,7 +319,7 @@ async def main(cfg: DictConfig): ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) await ts.initialize(strategy=ts.ControllerStorageVolumes()) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 53eec5cfb..c7a402b08 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -16,11 +16,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index ca88b349a..d9466dffa 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -19,11 +19,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index c46ee0620..95cd94e29 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -12,11 +12,12 @@ off_by_n: 1 # Off by one by default # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index c823afb29..a0704d3d9 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -21,6 +21,7 @@ from forge.controller.launcher import BaseLauncher, get_launcher from forge.observability.metric_actors import get_or_create_metric_logger +from forge.observability.utils import detect_actor_name_from_call_stack from forge.types import ProcessConfig, ProvisionerConfig @@ -262,8 +263,10 @@ def bootstrap(env: dict[str, str]): self._proc_host_map[procs] = host_mesh - # Spawn local logging actor on each process and register with global logger - _ = await get_or_create_metric_logger(procs) + # Detect actor name and spawn local logging actor on each process + process_name = detect_actor_name_from_call_stack() + _ = await get_or_create_metric_logger(procs, process_name=process_name) + return procs async def host_mesh_from_proc(self, proc_mesh: ProcMesh): diff --git a/src/forge/env_constants.py 
b/src/forge/env_constants.py index 3adcdfc41..9c8905012 100644 --- a/src/forge/env_constants.py +++ b/src/forge/env_constants.py @@ -11,7 +11,7 @@ # Force all timing methods in forge.observability.perf_tracker.py to use # CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function. -METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA" +METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU" # Makes forge.observability.metrics.record_metric a no-op FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS" diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index 52262eed5..f37dacebd 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -10,15 +10,15 @@ LocalFetcherActor, ) from .metrics import ( + BackendRole, ConsoleBackend, - # Utility functions get_actor_name_with_rank, get_logger_backend_class, - # Backend classes LoggerBackend, + LoggingMode, MaxAccumulator, MeanAccumulator, - # Accumulator classes + Metric, MetricAccumulator, MetricCollector, MinAccumulator, @@ -41,8 +41,12 @@ # Performance tracking "Tracer", "trace", + # Data classes + "Metric", + "BackendRole", # Enums "Reduce", + "LoggingMode", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index d67a66a83..57a723ec1 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -11,12 +11,16 @@ from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc from forge.observability.metrics import ( + BackendRole, get_logger_backend_class, LoggerBackend, + LoggingMode, MetricCollector, reduce_metrics_states, ) +from forge.observability.utils import detect_actor_name_from_call_stack + logger = logging.getLogger(__name__) _global_logger = None @@ -24,6 +28,7 @@ async def get_or_create_metric_logger( proc_mesh: ProcMesh | None = None, + process_name: str | None = None, ) -> "GlobalLoggingActor": """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), if not already initialized, registers it with the GlobalLoggingActor and returns the @@ -37,6 +42,8 @@ Args: proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `monarch.actor.this_proc()`. + process_name: Optional meaningful process name (e.g., "TrainActor", "GeneratorActor") for logging. + If None, will auto-detect from call stack or default to "UnknownActor" if not found. Returns: GlobalLoggingActor: The global logging controller. @@ -54,8 +61,8 @@ # Initialize logging backends await mlogger.init_backends({ - "console": {"reduce_across_ranks": True}, - "wandb": {"project": "my_project", "reduce_across_ranks": False} + "console": {"logging_mode": "global_reduce"}, + "wandb": {"project": "my_project", "logging_mode": "per_rank_no_reduce"} }) # Initialize services... @@ -63,13 +70,17 @@ # Training loop for step in range(max_steps): - record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) + record_metric("loss", 1.2, reduction=Reduce.MEAN) # ... training code with record_metric() calls ...
await mlogger.flush(step) # Log metrics for this step # Shutdown await mlogger.shutdown() """ + + if process_name is None: + process_name = detect_actor_name_from_call_stack() + # Get or create the singleton global logger global _global_logger if _global_logger is None: @@ -78,14 +89,11 @@ async def get_or_create_metric_logger( ) global_logger = _global_logger - # Determine process context + # Sanity check that if we already have a LocalFetcherActor, + # it is registered with the global logger proc = proc_mesh if proc_mesh is not None else this_proc() - - # Check current state for consistency proc_has_local_fetcher = hasattr(proc, "_local_fetcher") global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) - - # Consistency check: both should be in sync if proc_has_local_fetcher != global_logger_has_local_fetcher: raise ValueError( f"Inconsistent logging state for proc {proc}: " @@ -98,10 +106,10 @@ async def get_or_create_metric_logger( # Setup local_fetcher_actor if needed if not proc_has_local_fetcher: local_fetcher_actor = proc.spawn( - "local_fetcher_actor", LocalFetcherActor, global_logger + "local_fetcher_actor", LocalFetcherActor, global_logger, process_name ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) - proc._local_fetcher = local_fetcher_actor + proc._local_fetcher = local_fetcher_actor # pyre-ignore return global_logger @@ -114,19 +122,23 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ - def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None: + def __init__( + self, + global_logger: Optional["GlobalLoggingActor"] = None, + process_name: str | None = None, + ) -> None: self.global_logger = global_logger - _is_initialized = False + self.process_name = process_name # Passed MetricCollector for logging @endpoint async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. This should only ever be called by the global logger. Args: - step (int): train step used by backends to align all metrics on the same x-axis + global_step (int): step used by backends to align all metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: @@ -134,7 +146,7 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ collector = MetricCollector() - result = await collector.flush(step, return_state=return_state) + result = await collector.flush(global_step, return_state=return_state) return result @endpoint @@ -142,10 +154,22 @@ async def init_backends( self, metadata_per_primary_backend: Dict[str, Dict[str, Any]], config: Dict[str, Any], + global_step: int = 0, ): - """Init local (per-rank) logger backends and MetricCollector.""" + """Init local (per-rank) logger backends and MetricCollector. + + Args: + metadata_per_primary_backend: Metadata from primary backends for shared state. + config: Backend configurations with logging modes and settings. + global_step: Initial step for metrics. 
+ """ collector = MetricCollector() - await collector.init_backends(metadata_per_primary_backend, config) + await collector.init_backends( + metadata_per_primary_backend, + config, + global_step, + process_name=self.process_name, + ) @endpoint async def shutdown(self): @@ -155,7 +179,7 @@ async def shutdown(self): class GlobalLoggingActor(Actor): - """Coordinates metric logging across all ranks for every training step. + """Coordinates metric logging across all ranks for every step. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), for per-rank and/or global reduction logging modes. @@ -169,8 +193,8 @@ class GlobalLoggingActor(Actor): the per-rank MetricCollector. In summary, the flow is: - - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector - - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + - GlobalLoggingActor.init_backends() -> LocalFetcherActor.init_backends() -> per-rank MetricCollector.init_backends() + - GlobalLoggingActor.flush() -> LocalFetcherActor.flush() -> per-rank MetricCollector.flush """ def __init__(self): @@ -179,48 +203,74 @@ def __init__(self): self.global_logger_backends: Dict[str, LoggerBackend] = {} self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} + def _validate_backend_config( + self, backend_name: str, config: Dict[str, Any] + ) -> Dict[str, Any]: + """Validate and normalize backend configuration.""" + if "logging_mode" not in config: + logger.debug( + f"logging_mode not provided for backend {backend_name}. Defaulting to global_reduce." + ) + + mode_str = config.get("logging_mode", "global_reduce") + mode = LoggingMode(mode_str) + + # Validate per_rank_share_run configuration + share_run = config.get("per_rank_share_run", False) + if mode == LoggingMode.GLOBAL_REDUCE and share_run: + logger.warning( + f"{backend_name}: per_rank_share_run ignored in {mode.value} mode." + "Set it to False or change logging_mode to per rank." + ) + + return { + **config, + "logging_mode": mode, + } + @endpoint async def init_backends(self, config: Dict[str, Any]): - """ - Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors + """Sets config in global actor, initializes primary backends and eagerly initializes MetricCollectors in all registered fetchers. - A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source - for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. - - The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, - a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, - and will log independently, i.e. each rank will have its own run in wandb. - - Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states - and reduce them to a single value, which will be logged by the primary backend in this controller. + The backend instantiation is controlled by the logging_mode field. Primary backends + (instantiated in the controller) can provide metadata to be shared with secondary backends on ranks, + e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`. Args: - config (Dict[str, Any]): Config for metric logging where keys are backend names, - e.g. 
{"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} + config (Dict[str, Any]): Config for metric logging where keys are backend names. + Examples: + - {"console": {"logging_mode": "global_reduce"}} + - {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_project", "per_rank_share_run": True}} + + Raises: + ValueError: If backend config is invalid or missing required fields. """ - self.config = config + self.config = {} + # Validate and normalize each backend config for backend_name, backend_config in config.items(): + self.config[backend_name] = self._validate_backend_config( + backend_name, backend_config + ) + + # Initialize backends based on logging mode + for backend_name, backend_config in self.config.items(): + mode = backend_config["logging_mode"] + backend = get_logger_backend_class(backend_name)(backend_config) - await backend.init(role="global") - - # Extract metadata from primary logger to be shared with secondary loggers - # and store it - reduce_across_ranks = backend_config.get("reduce_across_ranks", True) - if not reduce_across_ranks: - primary_backend_metadata = ( - backend.get_metadata_for_secondary_ranks() or {} - ) - self.metadata_per_primary_backend[ - backend_name - ] = primary_backend_metadata + await backend.init(role=BackendRole.GLOBAL) + + # Extract metadata for per-rank shared modes + if mode != LoggingMode.GLOBAL_REDUCE: + primary_metadata = backend.get_metadata_for_secondary_ranks() or {} + self.metadata_per_primary_backend[backend_name] = primary_metadata - # Store global logger backends - if reduce_across_ranks: + # Store global backends for later flush + if mode == LoggingMode.GLOBAL_REDUCE: self.global_logger_backends[backend_name] = backend - # Eager init collectors on all registered fetchers in parallel, passing primary states and config + # Initialize per rank fetchers if self.fetchers: tasks = [ fetcher.init_backends.call( @@ -255,13 +305,13 @@ async def deregister_fetcher(self, name: str | ProcMesh): del self.fetchers[name] @endpoint - async def flush(self, step: int): + async def flush(self, global_step: int): """ Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors log to local backends and return states if needed for cross-rank reduction. Args: - step (int): Global step for logging. + global_step (int): step for logging. """ if not self.fetchers: return @@ -269,58 +319,63 @@ async def flush(self, step: int): config = self.config if config is None: logger.warning( - "GlobalLoggingActor flush() called before init_backends(). " - "No backends will be flushed." + "Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()." + " No backends will be flushed. 
Please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" ) return - # if reduce_across_ranks=True, we need to reduce the states from all ranks - # and log with the primary backend + + # Check if we need states for GLOBAL_REDUCE backends requires_reduce = any( - backend_config.get("reduce_across_ranks", True) + backend_config["logging_mode"] == LoggingMode.GLOBAL_REDUCE for backend_config in config.values() ) - logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + logger.debug( + f"Global flush for step {global_step}: {len(self.fetchers)} fetchers" + ) # Broadcast flush to all fetchers results = await asyncio.gather( *[ - f.flush.call(step, return_state=requires_reduce) + f.flush.call(global_step, return_state=requires_reduce) for f in self.fetchers.values() ], return_exceptions=True, ) if requires_reduce: - # Handle exceptions and extract values from ValueMesh results - all_local_states = [] - for result in results: - if isinstance(result, BaseException): - logger.warning(f"Flush failed on a fetcher: {result}") - continue - - # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] - for gpu_info, local_metric_state in result.items(): - if isinstance(local_metric_state, dict): - all_local_states.append(local_metric_state) - else: - logger.warning( - f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" - ) + + def extract_values_from_valuemesh(results): + all_local_states = [] + for result in results: + if isinstance(result, BaseException): + logger.warning(f"Flush failed on a fetcher: {result}") + continue + + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" + ) + return all_local_states + + all_local_states = extract_values_from_valuemesh(results) if not all_local_states: - logger.warning(f"No states to reduce for step {step}") + logger.warning(f"No states to reduce for global_step {global_step}") return - # Reduce + # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) - # Log to each global logger_backend - for ( - logger_backend_name, - logger_backend, - ) in self.global_logger_backends.items(): - await logger_backend.log(reduced_metrics, step) + # Log to global backends + for backend_name, backend in self.global_logger_backends.items(): + await backend.log_batch(reduced_metrics, global_step) @endpoint def has_fetcher(self, name: str | ProcMesh) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 990a301e0..522074f42 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -5,17 +5,58 @@ # LICENSE file in the root directory of this source tree. 
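Illustration, not part of the patch: a rough sketch of what the GLOBAL_REDUCE flush path above does with the per-rank accumulator states it gathers. The metric name and numbers are invented; the state layout follows the format shown in the flush() docstrings.

    from forge.observability.metrics import Reduce, reduce_metrics_states

    # States as returned by each rank's MetricCollector.flush(..., return_state=True)
    per_rank_states = [
        {"loss": {"reduction_type": Reduce.MEAN, "sum": 4.0, "count": 2}},  # rank 0
        {"loss": {"reduction_type": Reduce.MEAN, "sum": 6.0, "count": 3}},  # rank 1
    ]

    # The controller merges the raw states (not locally reduced values) and gets back
    # one Metric per key, which it then passes to each global backend's log_batch().
    reduced = reduce_metrics_states(per_rank_states)
    # -> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)]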
import logging - import os from abc import ABC, abstractmethod +from dataclasses import dataclass +from datetime import datetime from enum import Enum from typing import Any, Dict, List, Optional -from monarch.actor import context, current_rank +import pytz +from monarch.actor import current_rank + +from forge.observability.utils import get_actor_name_with_rank logger = logging.getLogger(__name__) +class BackendRole: + """Backend role constants for metric logging actors. + + Defines whether an actor operates as a local (per-rank) or global (controller) role + in the distributed metrics collection system. + """ + + LOCAL: str = "local" + GLOBAL: str = "global" + + +class LoggingMode(Enum): + """Metric logging behavior for distributed training scenarios. + + Each mode serves different observability needs: + + GLOBAL_REDUCE = "global_reduce" + Best for: Metrics that are best visualized as a single value per step. + Behavior: All ranks accumulate → controller reduces → single log entry + Example use: 8 ranks training, want 1 loss value per step averaged across all + + PER_RANK_REDUCE = "per_rank_reduce" + Best for: Per-rank performance metrics, debugging individual rank behavior + Behavior: Each rank accumulates + logs its own reduced values + Example use: Monitor GPU utilization per rank, get 8 separate log entries per step + + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + Best for: Real-time streaming, time-series debugging + Behavior: Raw values logged immediately on record_metric() calls + Example use: See what every rank is doing in real time. + """ + + GLOBAL_REDUCE = "global_reduce" + PER_RANK_REDUCE = "per_rank_reduce" + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -35,58 +76,26 @@ def accumulator_class(self): return mapping[self] -def get_actor_name_with_rank() -> str: - """ - Extracts actor information from Monarch context to form a logging name. +@dataclass +class Metric: + """Container for metric data including key, value, reduction type, and timestamp. - Returns: - str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). - Falls back to "UnknownActor" if context unavailable. + Timestamp is automatically set to current UTC time if not provided.
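As a quick illustration of the three logging modes above (not part of the patch; backend names and settings are examples only), a config passed to init_backends() could look like:

    config = {
        # One reduced value per metric per step, logged once from the controller
        "console": {"logging_mode": "global_reduce"},
        # Each rank reduces and logs its own values (its own wandb run by default)
        "wandb": {
            "project": "my_project",
            "logging_mode": "per_rank_reduce",
            # Only relevant for the per-rank modes: share a single wandb run
            # across ranks instead of creating one run per rank
            "per_rank_share_run": False,
        },
        # "per_rank_no_reduce" would instead stream raw values on every
        # record_metric() call, using the collector's current global_step
    }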
""" - # Add more defensive checks - ctx = context() - if ctx is None or ctx.actor_instance is None: - logger.warning("Context unavailable, using fallback actor name for logging.") - return "UnknownActor" - - actor_instance = ctx.actor_instance - rank = current_rank() - - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - rank_name = "UnknownActor" # fallback - if len(parts) >= 2: - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Extract world ID and proc rank - world_id = world_part.split("[")[0] if "[" in world_part else world_part - - # Extract clean actor name (remove "Configured" suffix if present) - if "[" in actor_part: - actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" - if actor_name.endswith("Configured"): - actor_name = actor_name[:-10] # Remove "Configured" - else: - actor_name = actor_part - - # Use last 4 characters of world_id as replica identifier - # This is deterministic, readable, and works for any number of replicas - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - # Use current_rank().rank as the local rank within the replica - local_rank = rank.rank - rank_name = f"{actor_name}_{replica_id}_r{local_rank}" + key: str + value: Any + reduction: Reduce + timestamp: Optional[float] = None - return rank_name + def __post_init__(self): + if self.timestamp is None: + # Always record in UTC timezone + self.timestamp = datetime.now(pytz.UTC).timestamp() def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: - """ - Records a metric value for later reduction and logging. + """Thin wrapper to send metrics to per-rank local MetricCollectors. Relies on a per-rank MetricCollector singleton for ease of use, i.e. call `record_metric` anywhere in the code without moving the @@ -101,16 +110,18 @@ def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None Can be disabled globally by setting the environment variable `FORGE_DISABLE_METRICS=true`. """ - # Skip metrics collection if disabled for tests + # Skip metrics collection if os.getenv("FORGE_DISABLE_METRICS", "false").lower() == "true": return + # timestamp is added automatically by the Metric class + metric = Metric(key=key, value=value, reduction=reduction) collector = MetricCollector() - collector.push(key, value, reduction) + collector.push(metric) -def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: - """Reduce metric accumulators states to a single value per metric. +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> List[Metric]: + """Reduce metric accumulators states to a list of metrics. Can be used when reducing metrics across ranks or services, as merging states is more precise than merging locally reduced metrics. @@ -120,7 +131,7 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. 
Returns: - Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + List[Metric]: List of reduced metrics Example: states = [ @@ -128,18 +139,18 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, {"loss": {"count": 10, "sum": 16, "reduction_type": Reduce.MEAN}}, ] reduce_metrics_states(states) - >>> {"loss": 2.0} + >>> [Metric(key="loss", value=2.0, reduction=Reduce.MEAN)] Raises: ValueError: on mismatched reduction types for the same metric key. """ if not states: - return {} + return [] # Collect unique keys across all all_keys = set(k for state in states for k in state) - reduced_metrics = {} + reduced_metrics = [] for key in all_keys: metric_states = [state.get(key) for state in states if key in state] if not metric_states: @@ -158,7 +169,14 @@ def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, metric_accumulator = Reduce(first_reduction_type).accumulator_class reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) - reduced_metrics[key] = reduced_value + + # Create Metric object with reduced value + metric = Metric( + key=key, + value=reduced_value, + reduction=Reduce(first_reduction_type), + ) + reduced_metrics.append(metric) return reduced_metrics @@ -367,7 +385,7 @@ class MetricCollector: - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also return non-reduced states for global aggregation. This can be different for each backend. - - Resets accumulators post-flush to avoid leaks across train steps; + - Resets accumulators post-flush to avoid leaks across steps; """ _instances: Dict[int, "MetricCollector"] = {} @@ -395,62 +413,129 @@ def __init__(self): self.accumulators: Dict[str, MetricAccumulator] = {} self.rank = current_rank().rank - self.logger_backends: List[LoggerBackend] = [] + self.per_rank_reduce_backends: List[LoggerBackend] = [] + self.per_rank_no_reduce_backends: List[LoggerBackend] = [] + self.global_step: int = 0 # Updated on flush self._is_initialized = False async def init_backends( self, metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]], config: Dict[str, Any], + global_step: int = 0, + process_name: str | None = None, ) -> None: - """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated - once globally. + """Initialize per-rank logger backends and MetricCollector state. + + A logger backend is represented by a backend class (e.g. WandBBackend, ConsoleBackend). + Backends are categorized by their logging_mode. For details, see `forge.observability.metrics.LoggingMode`. Args: metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary - logger backend, e.g., {"wandb": {"run_id": "abc123"}}. - config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + logger backends for backends that require shared state, e.g., + {"wandb": {"shared_run_id": "abc123"}} for shared WandB runs across ranks. + config (Dict[str, Any]): Backend configurations where each key is a backend name + and value contains logging_mode and backend-specific settings. + e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}} + global_step (int, default 0): Initial step for immediate logging. 
This allows + restarting from checkpoints with correct step numbering. + process_name (str | None): The meaningful process name for logging. """ if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return - # instantiate local backends if any + # Initialize step tracking for immediate logging + self.global_step = global_step + + self.per_rank_reduce_backends: List[LoggerBackend] = [] + self.per_rank_no_reduce_backends: List[LoggerBackend] = [] + + # Initialize backends based on logging mode for backend_name, backend_config in config.items(): - if backend_config.get("reduce_across_ranks", True): - continue # Skip local backend instantiation and use global instead + mode = LoggingMode(backend_config["logging_mode"]) - # get metadata from primary backend if any + # Skip local instantiation for GLOBAL_REDUCE + # Backend will be instantiated in GlobalLoggingActor + if mode == LoggingMode.GLOBAL_REDUCE: + continue + + # Get primary metadata if needed primary_metadata = {} if metadata_per_primary_backend: primary_metadata = metadata_per_primary_backend.get(backend_name, {}) - # instantiate local backend - logger_backend = get_logger_backend_class(backend_name)(backend_config) - await logger_backend.init( - role="local", primary_logger_metadata=primary_metadata + # Instantiate backend + backend = get_logger_backend_class(backend_name)(backend_config) + await backend.init( + role=BackendRole.LOCAL, + primary_logger_metadata=primary_metadata, + process_name=process_name, ) - self.logger_backends.append(logger_backend) + + # Categorize by logging mode + if mode == LoggingMode.PER_RANK_NO_REDUCE: + self.per_rank_no_reduce_backends.append(backend) + else: + self.per_rank_reduce_backends.append(backend) self._is_initialized = True - def push(self, key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: + def push(self, metric: Metric) -> None: + """Process a metric according to configured logging modes. + + Behavior depends on backend modes: + - PER_RANK_NO_REDUCE: Stream metric immediately to backends + - PER_RANK_REDUCE/GLOBAL_REDUCE: Accumulate for per step batch logging + + Args: + metric: Metric dataclass + + Example: + collector = MetricCollector() + metric = Metric("loss", 0.5, Reduce.MEAN) + collector.push(metric) # Streams immediately if no_reduce, else accumulates + """ if not self._is_initialized: - raise ValueError("Collector not initialized—call init first") + from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg=( + "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." + " This happens when you try to use `record_metric` before calling `init_backends`." 
+ " To disable this warning, please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "or set env variable `FORGE_DISABLE_METRICS=True`" + ), + ) + return - if key not in self.accumulators: - self.accumulators[key] = reduction.accumulator_class(reduction) + # Validate metric object + if not isinstance(metric, Metric): + raise TypeError(f"Expected {Metric} object, got {metric}") + + # For PER_RANK_NO_REDUCE backends: stream immediately + for backend in self.per_rank_no_reduce_backends: + backend.log_stream(metric=metric, global_step=self.global_step) - self.accumulators[key].append(value) + # Always accumulate for reduction and state return + key = metric.key + if key not in self.accumulators: + self.accumulators[key] = metric.reduction.accumulator_class( + metric.reduction + ) + self.accumulators[key].append(metric.value) async def flush( - self, step: int, return_state: bool = False + self, global_step: int, return_state: bool = False ) -> Dict[str, Dict[str, Any]]: """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. Args: - step (int): Step used by backends to align metrics on the same x-axis + global_step (int): step used by backends to align metrics on the same x-axis return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: @@ -458,14 +543,22 @@ async def flush( e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: - logger.debug( - f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." + from forge.util.logging import log_once + + log_once( + logger, + level=logging.WARNING, + msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." + "\nPlease call in your main file:\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" + "before calling `flush`", ) return {} if not self.accumulators: logger.debug( - f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {global_step}" ) return {} @@ -475,29 +568,30 @@ async def flush( states[key] = acc.get_state() acc.reset() - # Reduce metrics from states for logging if any per-rank backend - if self.logger_backends: - metrics = {} - for key, state in states.items(): - acc_class = Reduce(state["reduction_type"]).accumulator_class - metrics[key] = acc_class.get_reduced_value_from_states([state]) + # Log to PER_RANK_REDUCE backends only (NO_REDUCE already logged in push) + if self.per_rank_reduce_backends: + metrics_for_backends = reduce_metrics_states([states]) - # Log to local logger_backends - for logger_backend in self.logger_backends: - await logger_backend.log(metrics, step) + # Log to PER_RANK_REDUCE backends + for backend in self.per_rank_reduce_backends: + await backend.log_batch(metrics_for_backends, global_step) + + # Update step (used by NO_REDUCE backends in push) + self.global_step = global_step + 1 return states if return_state else {} async def shutdown(self): """Shutdown logger_backends if initialized.""" + if not self._is_initialized: logger.debug( f"Collector for {get_actor_name_with_rank()} not initialized. 
Skipping shutdown" ) return - for logger_backend in self.logger_backends: - await logger_backend.finish() + for backend in self.per_rank_reduce_backends + self.per_rank_no_reduce_backends: + await backend.finish() ########### @@ -516,12 +610,13 @@ async def init( self, role: str, primary_logger_metadata: Optional[Dict[str, Any]] = None, + process_name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). Args: - role (str): "global" (controller/primary) or "local" (per-rank/secondary). + role (BackendRole): "global" (controller/primary) or "local" (per-rank/secondary). Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. @@ -532,7 +627,24 @@ async def init( primary_logger_metadata = {} pass - async def log(self, metrics: Dict[str, Any], step: int) -> None: + async def log_batch( + self, metrics: List[Metric], global_step: int, *args, **kwargs + ) -> None: + """Log batch of accumulated metrics to backend""" + pass + + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to backend immediately. + + NOTE: This method is called synchronously. + If your backend requires async I/O operations: + - Use asyncio.create_task() for fire-and-forget logging + - Consider internal buffering to avoid blocking the caller + + Example for async backend: + def log_stream(self, metric, global_step): + asyncio.create_task(self._async_log(metric, global_step)) + """ pass async def finish(self) -> None: @@ -553,18 +665,24 @@ async def init( self, role: str, primary_logger_metadata: Optional[Dict[str, Any]] = None, + process_name: str | None = None, ) -> None: - self.prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("reduce_across_ranks", True) - else "GLOBAL" + pass + + async def log_batch( + self, metrics: List[Metric], global_step: int, *args, **kwargs + ) -> None: + metrics_str = "\n".join( + f" {metric.key}: {metric.value}" + for metric in sorted(metrics, key=lambda m: m.key) + ) + logger.info( + f"=== [METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: - logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") - for key, value in sorted(metrics.items()): - logger.info(f" {key}: {value}") - logger.info("==============================\n") + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream metric to console immediately.""" + logger.info(f"{metric.key}: {metric.value}") async def finish(self) -> None: pass @@ -572,18 +690,17 @@ async def finish(self) -> None: class WandbBackend(LoggerBackend): """ - Weights & Biases logging backend for distributed training. + Weights & Biases logging backend. - Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: - Track a single process: reduce_across_ranks=True - Track each process separately: reduce_across_ranks=False, share_run_id=False - Track all processes to a single run: reduce_across_ranks=False, share_run_id=True + For logging mode details, see `forge.observability.metrics.LoggingMode` documentation. 
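A minimal sketch of a third-party backend written against the new LoggerBackend interface shown above (illustration only; the JSONL file behavior and all names below are invented, not part of this patch):

    import json

    from forge.observability.metrics import LoggerBackend, Metric


    class JsonlFileBackend(LoggerBackend):
        """Hypothetical backend that appends metrics to a JSONL file."""

        def __init__(self, logger_backend_config):
            super().__init__(logger_backend_config)
            self.path = logger_backend_config.get("path", "metrics.jsonl")

        async def init(self, role, primary_logger_metadata=None, process_name=None):
            # role is BackendRole.GLOBAL on the controller and BackendRole.LOCAL on ranks
            self.process_name = process_name

        async def log_batch(self, metrics: list[Metric], global_step: int, *args, **kwargs) -> None:
            # Reduced metrics, delivered once per flush()
            with open(self.path, "a") as f:
                for m in metrics:
                    f.write(json.dumps({"step": global_step, m.key: m.value}) + "\n")

        def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None:
            # Called synchronously from record_metric() in per_rank_no_reduce mode
            with open(self.path, "a") as f:
                f.write(json.dumps({"step": global_step, metric.key: metric.value}) + "\n")

        async def finish(self) -> None:
            pass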
+ + More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/ Configuration: - reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). - If False, enables per-rank logging; then use share_run_id to pick mode. - share_run_id (bool, default False): Only used if reduce_across_ranks=False. - True -> shared run across ranks; False -> separate runs per rank. + logging_mode (LoggingMode): Determines logging behavior + per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks. + If true, then a single wandb is created and all ranks log to it. Its particularly useful if + logging with no_reduce to capture a time based stream of information. Not recommended if reducing values. project (str): WandB project name group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" """ @@ -594,44 +711,43 @@ def __init__(self, logger_backend_config: Dict[str, Any]): self.group = logger_backend_config.get("group", "experiment_group") self.name = None self.run = None - self.reduce_across_ranks = logger_backend_config.get( - "reduce_across_ranks", True - ) - self.share_run_id = logger_backend_config.get("share_run_id", False) + self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) + self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) async def init( self, role: str, primary_logger_metadata: Optional[Dict[str, Any]] = None, + process_name: str | None = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - if role not in ["global", "local"]: + if role not in [BackendRole.GLOBAL, BackendRole.LOCAL]: raise ValueError( - f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." + f"Invalid role {role} for WandbBackend init. Must be '{BackendRole.GLOBAL}' or '{BackendRole.LOCAL}'." ) self.name = ( - get_actor_name_with_rank() if role == "local" else "global_controller" + get_actor_name_with_rank(process_name) + if role == BackendRole.LOCAL + else "Controller" ) - # Default global mode: only inits on controller - if self.reduce_across_ranks: - if role != "global": - logger.debug( - f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
- ) + # GLOBAL_REDUCE mode: only inits on controller + if self.logging_mode == LoggingMode.GLOBAL_REDUCE: + if role != BackendRole.GLOBAL: + logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.") return await self._init_global() - # Per-rank modes based on share_run_id bool - elif role == "global" and self.share_run_id: + # Per-rank modes based on per_rank_share_run bool + elif role == BackendRole.GLOBAL and self.per_rank_share_run: await self._init_shared_global() - elif role == "local": - if self.share_run_id: + elif role == BackendRole.LOCAL: + if self.per_rank_share_run: await self._init_shared_local(primary_logger_metadata) else: await self._init_per_rank() @@ -652,34 +768,66 @@ async def _init_shared_global(self): settings = wandb.Settings( mode="shared", x_primary=True, x_label="controller_primary" ) - self.run = wandb.init(project=self.project, group=self.group, settings=settings) + + self.run = wandb.init( + project=self.project, + group=self.group, + settings=settings, + ) async def _init_shared_local(self, primary_metadata: Dict[str, Any]): import wandb + from wandb.sdk.lib.service import service_token shared_id = primary_metadata.get("shared_run_id") if shared_id is None: raise ValueError( f"Shared ID required but not provided for {self.name} backend init" ) + + # Clear any stale service tokens that might be pointing to dead processes + # In multiprocessing environments, WandB service tokens can become stale and point + # to dead service processes. This causes wandb.init() to hang indefinitely trying + # to connect to non-existent services. Clearing forces fresh service connection. + service_token.clear_service_in_env() + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) self.run = wandb.init( - id=shared_id, - project=self.project, - group=self.group, - settings=settings, + id=shared_id, project=self.project, group=self.group, settings=settings ) - async def log(self, metrics: Dict[str, Any], step: int) -> None: - if self.run: - log_data = {**metrics, "global_step": step} - self.run.log(log_data) - logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") - else: + async def log_batch( + self, metrics: List[Metric], global_step: int, *args, **kwargs + ) -> None: + if not self.run: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") + return + + # Convert metrics to WandB log format + log_data = {"step": global_step} + for metric in metrics: + log_data[metric.key] = metric.value + + self.run.log(log_data) + logger.info( + f"WandbBackend: Logged {len(metrics)} metrics at step {global_step}" + ) + + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to WandB with both step and timestamp.""" + if not self.run: + return + + # Log with both step and timestamp - users can choose x-axis in WandB UI + log_data = { + metric.key: metric.value, + "global_step": global_step, + "_timestamp": metric.timestamp, + } + self.run.log(log_data) def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]: - if self.run and not self.reduce_across_ranks and self.share_run_id: + if self.run and self.per_rank_share_run: return {"shared_run_id": self.run.id} return {} diff --git a/src/forge/observability/perf_tracker.py b/src/forge/observability/perf_tracker.py index e85b81e26..47577d916 100644 --- a/src/forge/observability/perf_tracker.py +++ b/src/forge/observability/perf_tracker.py @@ -15,7 +15,7 @@ import torch -from forge.env_constants import 
DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA +from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU from forge.observability.metrics import record_metric, Reduce # Thread-local memory tracking state @@ -125,7 +125,7 @@ def start(self) -> None: # Start timing (always enabled) time_with_gpu_events = ( - os.getenv(METRIC_TIMER_USES_CUDA, str(self.time_with_gpu)).lower() == "true" + os.getenv(METRIC_TIMER_USES_GPU, str(self.time_with_gpu)).lower() == "true" ) and torch.cuda.is_available() self._timer = _TimerCUDA() if time_with_gpu_events else _TimerCPU() self._timer.start() diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py new file mode 100644 index 000000000..f9fc18014 --- /dev/null +++ b/src/forge/observability/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from monarch.actor import context, current_rank + +logger = logging.getLogger(__name__) + + +def detect_actor_name_from_call_stack() -> str: + """Detect ForgeActor subclass name from call stack. + + Returns: + str: Actor name, defaulting to "UnknownActor" if not found. + """ + try: + import inspect + + frame = inspect.currentframe() + frame_count = 0 + + while frame: + frame = frame.f_back + if not frame: + break + + frame_count += 1 + if frame_count > 20: # Prevent infinite loops + break + + # Check for 'self' (instance method calls) + if "self" in frame.f_locals: + obj = frame.f_locals["self"] + if hasattr(obj, "__class__") and hasattr(obj.__class__, "__mro__"): + for base in obj.__class__.__mro__: + if base.__name__ == "ForgeActor": + return obj.__class__.__name__ + + # Check for 'cls' (class method calls) + if "cls" in frame.f_locals: + cls = frame.f_locals["cls"] + if hasattr(cls, "__mro__"): + for base in cls.__mro__: + if base.__name__ == "ForgeActor": + return cls.__name__ + + except Exception as e: + logger.debug(f"Call stack detection failed: {e}") + + return "UnknownActor" + + +def get_actor_name_with_rank(actor_name: Optional[str] = None) -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Args: + actor_name: Optional actor name to use. If None, will auto-detect from call stack. + + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. 
+ """ + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + if len(parts) < 2: + return "UnknownActor" + + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Use provided actor name or auto-detect from call stack + if actor_name: + final_actor_name = actor_name + else: + final_actor_name = detect_actor_name_from_call_stack() + + # Use last 4 characters of world_id as replica identifier + world_id = world_part.split("[")[0] if "[" in world_part else world_part + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + return f"{final_actor_name}_{replica_id}_r{rank.rank}" diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index d999fb700..2d291768b 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -7,7 +7,7 @@ import asyncio import logging -import time +from datetime import datetime from forge.controller.actor import ForgeActor from forge.controller.provisioner import shutdown @@ -17,7 +17,13 @@ from monarch.actor import current_rank, endpoint -logging.basicConfig(level=logging.DEBUG) +logging.basicConfig( + level=logging.DEBUG, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logging.getLogger("forge.observability.metrics").setLevel(logging.INFO) +logging.getLogger("forge.observability.metric_actors").setLevel(logging.INFO) +# Reduce wandb logging noise +logging.getLogger("wandb").setLevel(logging.WARNING) class TrainActor(ForgeActor): @@ -78,35 +84,34 @@ async def generate_step(self, step: int, substep: int): # Main async def main(): - """Example demonstrating distributed metric logging with different backends.""" - group = f"grpo_exp_{int(time.time())}" + group = "time" + str(int(datetime.now().timestamp())) # Config format: {backend_name: backend_config_dict} - # Each backend can specify reduce_across_ranks to control distributed logging behavior config = { - "console": {"reduce_across_ranks": True}, + "console": {"logging_mode": "global_reduce"}, "wandb": { - "project": "my_project", + "project": "toy_metrics", "group": group, - "reduce_across_ranks": False, - # Only useful if NOT reduce_across_ranks. - "share_run_id": False, # Share run ID across ranks -- Not recommended. 
+ "logging_mode": "per_rank_reduce", # global_reduce, per_rank_reduce, per_rank_no_reduce + "per_rank_share_run": False, }, } service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(config) - # Spawn services first (triggers registrations via provisioner hook) + # Spawn services (will register fetchers) trainer = await TrainActor.options(**service_config).as_service() generator = await GeneratorActor.options(**service_config).as_service() for i in range(3): print(f"\n=== Global Step {i} ===") + record_metric("main/global_step", 1, Reduce.MEAN) await trainer.train_step.fanout(i) for sub in range(3): await generator.generate_step.fanout(i, sub) + await asyncio.sleep(0.1) await mlogger.flush.call_one(i) # shutdown diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 0f3ce662c..19b5621c1 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -33,7 +33,7 @@ async def run(cfg: DictConfig): ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) if (prompt := cfg.get("prompt")) is None: diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py new file mode 100644 index 000000000..aa95de277 --- /dev/null +++ b/tests/unit_tests/observability/conftest.py @@ -0,0 +1,133 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Shared fixtures and mocks for observability unit tests.""" + +from unittest.mock import MagicMock, patch + +import pytest +from forge.observability.metrics import LoggerBackend, MetricCollector + + +class MockBackend(LoggerBackend): + """Mock backend for testing metrics logging without external dependencies.""" + + def __init__(self, logger_backend_config=None): + super().__init__(logger_backend_config or {}) + self.logged_metrics = [] + self.immediate_metrics = [] + self.init_called = False + self.finish_called = False + self.metadata = {} + + async def init(self, role="local", primary_logger_metadata=None, process_name=None): + self.init_called = True + self.role = role + self.primary_logger_metadata = primary_logger_metadata or {} + self.process_name = process_name + + def log_stream(self, metric, global_step, *args, **kwargs): + self.immediate_metrics.append((metric, global_step)) + + async def log_batch(self, metrics, global_step, *args, **kwargs): + for metric in metrics: + self.logged_metrics.append((metric, global_step)) + + async def finish(self): + self.finish_called = True + + def get_metadata_for_secondary_ranks(self): + return self.metadata + + +@pytest.fixture(autouse=True) +def clear_metric_collector_singletons(): + """Clear MetricCollector singletons before each test to avoid state leakage.""" + MetricCollector._instances.clear() + yield + MetricCollector._instances.clear() + + +@pytest.fixture(autouse=True) +def clean_metrics_environment(): + """Ensure clean environment state for metrics tests.""" + import os + + # Save original environment state + original_env = os.environ.get("FORGE_DISABLE_METRICS") + + # Set default state for tests (metrics enabled) + if "FORGE_DISABLE_METRICS" in os.environ: + del os.environ["FORGE_DISABLE_METRICS"] + + yield + + # Restore original environment state + if original_env is not None: + os.environ["FORGE_DISABLE_METRICS"] = original_env + elif "FORGE_DISABLE_METRICS" in os.environ: + del os.environ["FORGE_DISABLE_METRICS"] + + +@pytest.fixture +def mock_rank(): + """Mock current_rank function with configurable rank.""" + with patch("forge.observability.metrics.current_rank") as mock: + rank_obj = MagicMock() + rank_obj.rank = 0 + mock.return_value = rank_obj + yield mock + + +@pytest.fixture +def mock_actor_context(): + """Mock Monarch actor context for testing actor name generation.""" + with patch("forge.observability.metrics.context") as mock_context, patch( + "forge.observability.metrics.current_rank" + ) as mock_rank: + + # Setup mock context + ctx = MagicMock() + actor_instance = MagicMock() + actor_instance.actor_id = "_1rjutFUXQrEJ[0].TestActorConfigured[0]" + ctx.actor_instance = actor_instance + mock_context.return_value = ctx + + # Setup mock rank + rank_obj = MagicMock() + rank_obj.rank = 0 + mock_rank.return_value = rank_obj + + yield { + "context": mock_context, + "rank": mock_rank, + "expected_name": "TestActor_0XQr_r0", + } + + +@pytest.fixture +def initialized_collector(): + """Create an initialized MetricCollector with mock backends for testing.""" + with patch("forge.observability.metrics.current_rank") as mock_rank: + mock_rank.return_value = MagicMock(rank=0) + + MetricCollector._instances.clear() + collector = MetricCollector() + + # Setup mock backends + no_reduce_backend = MockBackend() + reduce_backend = MockBackend() + + collector._is_initialized = True + collector.per_rank_no_reduce_backends = [no_reduce_backend] + collector.per_rank_reduce_backends = [reduce_backend] + collector.global_step = 0 + + yield { 
+            "collector": collector,
+            "no_reduce_backend": no_reduce_backend,
+            "reduce_backend": reduce_backend,
+        }
diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py
new file mode 100644
index 000000000..c2b0a2992
--- /dev/null
+++ b/tests/unit_tests/observability/test_metric_actors.py
@@ -0,0 +1,192 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Unit tests for metric actors functionality."""
+
+import pytest
+
+from forge.observability.metric_actors import (
+    get_or_create_metric_logger,
+    GlobalLoggingActor,
+    LocalFetcherActor,
+)
+from forge.observability.metrics import LoggingMode
+from monarch.actor import this_host
+
+
+@pytest.fixture
+def global_logger():
+    """Create a GlobalLoggingActor for testing."""
+    p = this_host().spawn_procs(per_host={"cpus": 1})
+    return p.spawn("TestGlobalLogger", GlobalLoggingActor)
+
+
+@pytest.fixture
+def local_fetcher(global_logger):
+    """Create a LocalFetcherActor linked to the global logger."""
+    p = this_host().spawn_procs(per_host={"cpus": 1})
+    return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger)
+
+
+class TestBasicOperations:
+    """Test basic operations for actors."""
+
+    @pytest.mark.asyncio
+    async def test_local_fetcher_flush(self, local_fetcher):
+        """Test LocalFetcherActor flush operations."""
+        result_with_state = await local_fetcher.flush.call_one(
+            global_step=1, return_state=True
+        )
+        assert result_with_state == {}
+
+        result_without_state = await local_fetcher.flush.call_one(
+            global_step=1, return_state=False
+        )
+        assert result_without_state == {}
+
+    @pytest.mark.asyncio
+    async def test_global_logger_basic_ops(self, global_logger):
+        """Test GlobalLoggingActor basic operations."""
+        count = await global_logger.get_fetcher_count.call_one()
+        assert count >= 0
+
+        has_fetcher = await global_logger.has_fetcher.call_one("nonexistent")
+        assert has_fetcher is False
+
+        # Global logger flush (should not raise error)
+        await global_logger.flush.call_one(global_step=1)
+
+    @pytest.mark.asyncio
+    async def test_backend_init(self, local_fetcher):
+        """Test backend initialization and shutdown."""
+        metadata = {"wandb": {"shared_run_id": "test123"}}
+        config = {"console": {"logging_mode": "per_rank_reduce"}}
+
+        await local_fetcher.init_backends.call_one(metadata, config, global_step=5)
+        await local_fetcher.shutdown.call_one()
+
+
+class TestRegistrationLifecycle:
+    """Test registration lifecycle."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_registration_lifecycle(self, global_logger, local_fetcher):
+        """Test complete registration/deregistration lifecycle."""
+        proc_name = "lifecycle_test_proc"
+
+        # Initial state
+        initial_count = await global_logger.get_fetcher_count.call_one()
+        assert await global_logger.has_fetcher.call_one(proc_name) is False
+
+        # Register
+        await global_logger.register_fetcher.call_one(local_fetcher, proc_name)
+
+        # Verify registered
+        new_count = await global_logger.get_fetcher_count.call_one()
+        assert new_count == initial_count + 1
+        assert await global_logger.has_fetcher.call_one(proc_name) is True
+
+        # Deregister
+        await global_logger.deregister_fetcher.call_one(proc_name)
+
+        # Verify deregistered
+        final_count = await global_logger.get_fetcher_count.call_one()
+        assert final_count == initial_count
+        assert await global_logger.has_fetcher.call_one(proc_name) is False
+
+
+class TestBackendConfiguration:
+    """Test backend configuration validation."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_valid_backend_configs(self, global_logger):
+        """Test valid backend configurations."""
+        # Empty config
+        await global_logger.init_backends.call_one({})
+
+        # Valid configs for all logging modes
+        for mode in ["per_rank_reduce", "per_rank_no_reduce", "global_reduce"]:
+            config = {"console": {"logging_mode": mode}}
+            await global_logger.init_backends.call_one(config)
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_invalid_backend_configs(self, global_logger):
+        """Test invalid backend configurations raise errors."""
+        from monarch.actor import ActorError
+
+        # Missing logging_mode should work (has fallback to global_reduce)
+        await global_logger.init_backends.call_one({"console": {}})
+
+        # Invalid logging_mode should raise error (wrapped in ActorError since it's in an actor call)
+        with pytest.raises(ActorError):
+            await global_logger.init_backends.call_one(
+                {"console": {"logging_mode": "invalid_mode"}}
+            )
+
+
+class TestErrorHandling:
+    """Test error handling scenarios."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_deregister_nonexistent_fetcher(self, global_logger):
+        """Test deregistering a non-existent fetcher doesn't crash."""
+        await global_logger.deregister_fetcher.call_one("nonexistent_proc")
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_shutdown(self, global_logger):
+        """Test shutdown without issues."""
+        await global_logger.shutdown.call_one()
+
+
+class TestGetOrCreateMetricLogger:
+    """Test the integration function."""
+
+    @pytest.mark.timeout(3)
+    @pytest.mark.asyncio
+    async def test_get_or_create_functionality(self):
+        """Test get_or_create_metric_logger basic functionality."""
+        result = await get_or_create_metric_logger()
+
+        # Should return a GlobalLoggingActor mesh
+        assert result is not None
+
+        # Should be able to call basic methods
+        count = await result.get_fetcher_count.call_one()
+        assert count >= 0
+
+
+class TestSynchronousLogic:
+    """Test synchronous logic without the actor system (fastest tests)."""
+
+    def test_all_validation_logic(self):
+        """Test all synchronous validation logic in one place."""
+        actor = GlobalLoggingActor()
+
+        # Test 1: Valid config validation
+        config = {"logging_mode": "per_rank_reduce", "project": "test_project"}
+        result = actor._validate_backend_config("test_backend", config)
+        assert result["logging_mode"] == LoggingMode.PER_RANK_REDUCE
+        assert result["project"] == "test_project"
+
+        # Test 2: Missing logging_mode (should work with default)
+        result2 = actor._validate_backend_config(
+            "test_backend", {"project": "test_project"}
+        )
+        assert (
+            result2["logging_mode"] == LoggingMode.GLOBAL_REDUCE
+        )  # Should default to global_reduce
+        assert result2["project"] == "test_project"
+
+        # Test 3: Invalid logging_mode error (enum will raise ValueError)
+        with pytest.raises(ValueError, match="is not a valid LoggingMode"):
+            actor._validate_backend_config(
+                "test_backend", {"logging_mode": "invalid_mode"}
+            )
diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py
new file mode 100644
index 000000000..c11467ff8
--- /dev/null
+++ b/tests/unit_tests/observability/test_metrics.py
@@ -0,0 +1,459 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""Unit tests for core metrics functionality."""
+
+import time
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from forge.observability.metrics import (
+    ConsoleBackend,
+    get_logger_backend_class,
+    MaxAccumulator,
+    MeanAccumulator,
+    Metric,
+    MetricCollector,
+    MinAccumulator,
+    record_metric,
+    Reduce,
+    reduce_metrics_states,
+    StdAccumulator,
+    SumAccumulator,
+    WandbBackend,
+)
+
+
+class TestMetricCreation:
+    """Test Metric object creation and record_metric function."""
+
+    def test_metric_creation_automatic_timestamp(self, mock_rank):
+        """Test Metric object creation with automatic timestamp."""
+        before_time = time.time()
+        metric = Metric("test_key", 42.0, Reduce.MEAN)
+        after_time = time.time()
+
+        assert metric.key == "test_key"
+        assert metric.value == 42.0
+        assert metric.reduction == Reduce.MEAN
+        assert metric.timestamp is not None
+        assert before_time <= metric.timestamp <= after_time
+
+    def test_metric_creation_custom_timestamp(self, mock_rank):
+        """Test Metric object creation with custom timestamp."""
+        custom_time = 1234567890.0
+        metric = Metric("test_key2", 24.0, Reduce.SUM, timestamp=custom_time)
+        assert metric.timestamp == custom_time
+
+    def test_record_metric(self, mock_rank):
+        """Test record_metric creates correct Metric and calls collector."""
+        # Mock the MetricCollector constructor to return a mock instance
+        mock_collector = MagicMock()
+
+        with patch(
+            "forge.observability.metrics.MetricCollector", return_value=mock_collector
+        ):
+            record_metric("loss", 1.5, Reduce.MEAN)
+
+        # Verify push was called on the mock collector
+        mock_collector.push.assert_called_once()
+
+        # Verify the metric passed to push
+        pushed_metric = mock_collector.push.call_args[0][0]
+        assert pushed_metric.key == "loss"
+        assert pushed_metric.value == 1.5
+        assert pushed_metric.reduction == Reduce.MEAN
+
+    @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "true"})
+    @patch("forge.observability.metrics.MetricCollector")
+    def test_record_metric_disabled(self, mock_collector_class):
+        """Test record_metric is no-op when FORGE_DISABLE_METRICS=true."""
+        record_metric("loss", 1.5, Reduce.MEAN)
+        mock_collector_class.assert_not_called()
+
+    @patch.dict("os.environ", {"FORGE_DISABLE_METRICS": "false"})
+    @patch("forge.observability.metrics.MetricCollector")
+    def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank):
+        """Test record_metric works when FORGE_DISABLE_METRICS=false."""
+        mock_collector = MagicMock()
+        mock_collector_class.return_value = mock_collector
+
+        record_metric("loss", 1.5, Reduce.MEAN)
+        mock_collector_class.assert_called_once()
+        mock_collector.push.assert_called_once()
+
+
+class TestAccumulators:
+    """Test all accumulator classes and their operations."""
+
+    def test_mean_accumulator(self):
+        """Test MeanAccumulator operations."""
+        acc = MeanAccumulator(Reduce.MEAN)
+
+        # Test initial state
+        assert acc.get_value() == 0.0
+        state = acc.get_state()
+        assert state["sum"] == 0.0
+        assert state["count"] == 0
+
+        # Test append and get_value
+        acc.append(10.0)
+        acc.append(20.0)
+        assert acc.get_value() == 15.0
+
+        # Test state
+        state = acc.get_state()
+        assert state["sum"] == 30.0
+        assert state["count"] == 2
+        assert state["reduction_type"] == "mean"
+
+        # Test reset
+        acc.reset()
+        assert acc.get_value() == 0.0
+        assert acc.get_state()["sum"] == 0.0
+        assert acc.get_state()["count"] == 0
+
+    def test_sum_accumulator(self):
+        """Test SumAccumulator operations."""
+        acc = SumAccumulator(Reduce.SUM)
+
+        acc.append(5.0)
+        acc.append(3.0)
+        assert acc.get_value() == 8.0
+
+        state = acc.get_state()
+        assert state["total"] == 8.0
+        assert state["reduction_type"] == "sum"
+
+        acc.reset()
+        assert acc.get_value() == 0.0
+
+    def test_max_accumulator(self):
+        """Test MaxAccumulator operations."""
+        acc = MaxAccumulator(Reduce.MAX)
+
+        acc.append(5.0)
+        acc.append(10.0)
+        acc.append(3.0)
+        assert acc.get_value() == 10.0
+
+        state = acc.get_state()
+        assert state["max_val"] == 10.0
+        assert state["reduction_type"] == "max"
+
+    def test_min_accumulator(self):
+        """Test MinAccumulator operations."""
+        acc = MinAccumulator(Reduce.MIN)
+
+        acc.append(5.0)
+        acc.append(10.0)
+        acc.append(3.0)
+        assert acc.get_value() == 3.0
+
+        state = acc.get_state()
+        assert state["min_val"] == 3.0
+        assert state["reduction_type"] == "min"
+
+    def test_std_accumulator(self):
+        """Test StdAccumulator operations."""
+        acc = StdAccumulator(Reduce.STD)
+
+        # Test with zero/one values
+        assert acc.get_value() == 0.0
+        acc.append(5.0)
+        assert acc.get_value() == 0.0  # std of single value is 0
+
+        # Test with multiple values
+        acc.append(7.0)  # values: 5, 7, mean=6, std=1
+        assert abs(acc.get_value() - 1.0) < 0.001
+
+        state = acc.get_state()
+        assert state["sum"] == 12.0
+        assert state["sum_sq"] == 74.0  # 5^2 + 7^2 = 25 + 49 = 74
+        assert state["count"] == 2
+
+    @pytest.mark.parametrize(
+        "accumulator_class,states,expected",
+        [
+            (
+                MeanAccumulator,
+                [
+                    {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                    {"reduction_type": "mean", "sum": 20.0, "count": 3},
+                ],
+                6.0,  # (10+20) / (2+3)
+            ),
+            (
+                SumAccumulator,
+                [
+                    {"reduction_type": "sum", "total": 10.0},
+                    {"reduction_type": "sum", "total": 15.0},
+                ],
+                25.0,
+            ),
+        ],
+    )
+    def test_accumulator_state_reduction(self, accumulator_class, states, expected):
+        """Test cross-accumulator state reduction."""
+        result = accumulator_class.get_reduced_value_from_states(states)
+        assert result == expected
+
+    def test_reduce_enum_accumulator_mapping(self):
+        """Test that Reduce enum correctly maps to accumulator classes."""
+        assert Reduce.MEAN.accumulator_class == MeanAccumulator
+        assert Reduce.SUM.accumulator_class == SumAccumulator
+        assert Reduce.MAX.accumulator_class == MaxAccumulator
+        assert Reduce.MIN.accumulator_class == MinAccumulator
+        assert Reduce.STD.accumulator_class == StdAccumulator
+
+
+class TestMetricCollector:
+    """Test MetricCollector singleton behavior and operations."""
+
+    def test_singleton_per_rank(self, mock_rank):
+        """Test MetricCollector singleton behavior per rank."""
+        mock_rank.return_value.rank = 0
+        collector1 = MetricCollector()
+        collector2 = MetricCollector()
+        assert collector1 is collector2
+
+        # Different rank should get different instance
+        mock_rank.return_value.rank = 1
+        collector3 = MetricCollector()
+        assert collector1 is not collector3
+
+    def test_uninitialized_push_logs_warning(self, mock_rank, caplog):
+        """Test MetricCollector.push() logs warning when uninitialized."""
+        collector = MetricCollector()
+        metric = Metric("test", 1.0, Reduce.MEAN)
+
+        # just log warning and return
+        collector.push(metric)
+        assert any(
+            "Metric logging backends" in record.message for record in caplog.records
+        )
+
+    def test_invalid_metric_type_raises_error(self, mock_rank):
+        """Test MetricCollector.push() raises error for invalid metric type."""
+        collector = MetricCollector()
+        collector._is_initialized = True
+        collector.per_rank_no_reduce_backends = []
+        collector.per_rank_reduce_backends = []
+
+        with pytest.raises(TypeError, match="Expected .* object, got"):
+            # Type ignore because we're intentionally testing invalid input
+            collector.push("invalid_metric")  # type: ignore
+
+    @patch("forge.observability.metrics.get_actor_name_with_rank")
+    @pytest.mark.asyncio
+    async def test_push_and_flush(self, mock_actor_name, initialized_collector):
+        """Test MetricCollector push and flush with mock backends."""
+        mock_actor_name.return_value = "TestActor_abcd_r0"
+
+        collector = initialized_collector["collector"]
+        no_reduce_backend = initialized_collector["no_reduce_backend"]
+        reduce_backend = initialized_collector["reduce_backend"]
+
+        # Test push
+        metric = Metric("loss", 1.5, Reduce.MEAN)
+        collector.push(metric)
+
+        # Should log immediately to no_reduce backend
+        assert len(no_reduce_backend.immediate_metrics) == 1
+        assert no_reduce_backend.immediate_metrics[0][0].key == "loss"
+        assert no_reduce_backend.immediate_metrics[0][1] == 0  # step
+
+        # Should not log to reduce backend yet
+        assert len(reduce_backend.logged_metrics) == 0
+
+        # Test flush
+        result = await collector.flush(global_step=1, return_state=True)
+
+        # Should have returned state
+        assert "loss" in result
+        assert result["loss"]["reduction_type"] == "mean"
+        assert result["loss"]["sum"] == 1.5
+        assert result["loss"]["count"] == 1
+
+        # Should have logged to reduce backend
+        assert len(reduce_backend.logged_metrics) == 1
+        logged_metric, global_step = reduce_backend.logged_metrics[0]
+        assert logged_metric.key == "loss"
+        assert logged_metric.value == 1.5
+        assert global_step == 1
+
+    @pytest.mark.asyncio
+    async def test_flush_uninitialized_returns_empty(self, mock_rank):
+        """Test MetricCollector.flush() returns empty dict when uninitialized."""
+        collector = MetricCollector()
+        result = await collector.flush(global_step=1, return_state=True)
+        assert result == {}
+
+    @pytest.mark.asyncio
+    async def test_flush_no_metrics_returns_empty(self, mock_rank):
+        """Test MetricCollector.flush() returns empty dict when no metrics."""
+        collector = MetricCollector()
+        collector._is_initialized = True
+        collector.per_rank_no_reduce_backends = []
+        collector.per_rank_reduce_backends = []
+
+        result = await collector.flush(global_step=1, return_state=True)
+        assert result == {}
+
+    @pytest.mark.asyncio
+    async def test_step_counter_for_no_reduce_backend(self, initialized_collector):
+        """Test step counter increments correctly for no_reduce backends."""
+        collector = initialized_collector["collector"]
+        no_reduce_backend = initialized_collector["no_reduce_backend"]
+
+        # Clean slate
+        no_reduce_backend.immediate_metrics.clear()
+
+        # Start with step 0
+        assert collector.global_step == 0
+
+        # Push first metric - should use current step (0)
+        first_metric = Metric("loss", 1.0, Reduce.MEAN)
+        collector.push(first_metric)
+
+        # Verify: first metric logged with step 0
+        assert len(no_reduce_backend.immediate_metrics) == 1
+        first_logged_metric, first_step = no_reduce_backend.immediate_metrics[0]
+        assert first_logged_metric.key == "loss"
+        assert first_step == 0
+
+        # Flush at step 5 - this should increment collector.global_step to 6
+        await collector.flush(global_step=5)
+        assert collector.global_step == 6
+
+        # Push second metric - should use new step (6)
+        second_metric = Metric("accuracy", 0.9, Reduce.MEAN)
+        collector.push(second_metric)
+
+        # Verify: second metric logged with step 6
+        assert len(no_reduce_backend.immediate_metrics) == 2
+        second_logged_metric, second_step = no_reduce_backend.immediate_metrics[1]
+        assert second_logged_metric.key == "accuracy"
+        assert second_step == 6
+
+
+class TestReduceOperations:
+    """Test reduce_metrics_states function."""
+
+    def test_empty_states(self):
+        """Test reduce_metrics_states with empty input."""
+        result = reduce_metrics_states([])
+        assert result == []
+
+    def test_single_state(self):
+        """Test reduce_metrics_states with single state."""
+        states = [{"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}}]
+        result = reduce_metrics_states(states)
+        assert len(result) == 1
+        assert result[0].key == "loss"
+        assert result[0].value == 5.0
+        assert result[0].reduction == Reduce.MEAN
+
+    def test_multiple_states(self):
+        """Test reduce_metrics_states with multiple states."""
+        states = [
+            {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}},
+            {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}},
+            {"accuracy": {"reduction_type": "sum", "total": 15.0}},
+        ]
+        result = reduce_metrics_states(states)
+
+        # Convert to dict for easier testing
+        result_dict = {metric.key: metric.value for metric in result}
+        assert result_dict["loss"] == 30.0 / 5.0  # 6.0
+        assert result_dict["accuracy"] == 15.0
+
+        # Also check reduction types
+        for metric in result:
+            if metric.key == "loss":
+                assert metric.reduction == Reduce.MEAN
+            elif metric.key == "accuracy":
+                assert metric.reduction == Reduce.SUM
+
+    def test_mismatched_reduction_types_raises_error(self):
+        """Test reduce_metrics_states raises error for mismatched reduction types."""
+        states = [
+            {"loss": {"reduction_type": "mean", "sum": 10.0, "count": 2}},
+            {"loss": {"reduction_type": "sum", "total": 20.0}},
+        ]
+        with pytest.raises(ValueError, match="Mismatched reduction types"):
+            reduce_metrics_states(states)
+
+    def test_partial_key_overlap(self):
+        """Test reduce_metrics_states with partial key overlap."""
+        states = [
+            {
+                "loss": {"reduction_type": "mean", "sum": 10.0, "count": 2},
+                "accuracy": {"reduction_type": "sum", "total": 5.0},
+            },
+            {"loss": {"reduction_type": "mean", "sum": 20.0, "count": 3}},
+            {"throughput": {"reduction_type": "max", "max_val": 100.0}},
+        ]
+        result = reduce_metrics_states(states)
+
+        # Convert to dict for easier testing
+        result_dict = {metric.key: metric.value for metric in result}
+        assert result_dict["loss"] == 30.0 / 5.0  # 6.0
+        assert result_dict["accuracy"] == 5.0
+        assert result_dict["throughput"] == 100.0
+
+
+class TestBackends:
+    """Test backend classes and factory function."""
+
+    def test_backend_factory(self):
+        """Test get_logger_backend_class factory function."""
+        assert get_logger_backend_class("console") == ConsoleBackend
+        assert get_logger_backend_class("wandb") == WandbBackend
+
+        with pytest.raises(ValueError, match="Unknown logger backend type"):
+            get_logger_backend_class("invalid_backend")
+
+    @patch("forge.observability.metrics.get_actor_name_with_rank")
+    @pytest.mark.asyncio
+    async def test_console_backend(self, mock_actor_name):
+        """Test ConsoleBackend basic operations."""
+        mock_actor_name.return_value = "TestActor_abcd_r0"
+
+        backend = ConsoleBackend({})
+
+        await backend.init(role="local")
+
+        # Test log_stream
+        metric = Metric("test", 1.0, Reduce.MEAN)
+        backend.log_stream(metric, global_step=1)  # Should not raise
+
+        # Test log_batch
+        await backend.log_batch([metric], global_step=1)  # Should not raise
+
+        await backend.finish()  # Should not raise
+
+    @patch("forge.observability.metrics.get_actor_name_with_rank")
+    @pytest.mark.asyncio
+    async def test_wandb_backend_creation(self, mock_actor_name):
+        """Test WandbBackend creation and basic setup."""
+        mock_actor_name.return_value = "TestActor_abcd_r0"
+
+        config = {
+            "project": "test_project",
+            "group": "test_group",
+            "logging_mode": "per_rank_reduce",
+        }
+        backend = WandbBackend(config)
+
+        assert backend.project == "test_project"
+        assert backend.group == "test_group"
+        assert backend.per_rank_share_run is False  # default
+
+        # Test metadata method
+        metadata = backend.get_metadata_for_secondary_ranks()
+        assert metadata == {}  # Should be empty when no run
diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py
index 6af7331f1..7b7ba3d3d 100644
--- a/tests/unit_tests/observability/test_perf_tracker.py
+++ b/tests/unit_tests/observability/test_perf_tracker.py
@@ -12,7 +12,7 @@
 import pytest
 import torch
 
-from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_CUDA
+from forge.env_constants import DISABLE_PERF_METRICS, METRIC_TIMER_USES_GPU
 from forge.observability.metrics import Reduce
 from forge.observability.perf_tracker import _TimerCPU, _TimerCUDA, trace, Tracer
 
@@ -135,7 +135,7 @@ def test_comprehensive_workflow(
         if timer == "gpu" and not torch.cuda.is_available():
             pytest.skip("CUDA not available")
 
-        monkeypatch.setenv(METRIC_TIMER_USES_CUDA, str(timer == "gpu"))
+        monkeypatch.setenv(METRIC_TIMER_USES_GPU, str(timer == "gpu"))
 
         async def run_concurrent_tasks():
             start_time = time.perf_counter()
@@ -276,11 +276,9 @@ def test_timer_parameter_validation(self):
         with pytest.raises(ValueError, match='timer must be "cpu" or "gpu"'):
             trace("test", timer="invalid")
 
-        # Valid values should work
-        tracer_cpu = Tracer("test", timer="cpu")
-        tracer_cuda = Tracer("test", timer="gpu")
-        assert tracer_cpu is not None
-        assert tracer_cuda is not None
+        # Valid values should work without errors
+        Tracer("test", timer="cpu")
+        Tracer("test", timer="gpu")
 
     def test_tracer_and_timer_reuse(self, mock_record_metric_calls):
         """Test both tracer and timer backends can be reused."""
@@ -370,17 +368,17 @@ async def disabled_workflow():
             ("false", _TimerCPU),
         ],
     )
-    def test_metric_timer_uses_cuda_override(
+    def test_metric_timer_uses_gpu_override(
         self, env_value, expected_backend, monkeypatch
     ):
-        """Test METRIC_TIMER_USES_CUDA env var overrides timer parameter."""
+        """Test METRIC_TIMER_USES_GPU env var overrides timer parameter."""
         if env_value == "true" and not torch.cuda.is_available():
             pytest.skip("CUDA not available")
 
         with patch("torch.cuda.is_available", return_value=True), patch(
             "forge.observability.perf_tracker.record_metric"
         ):
-            monkeypatch.setenv(METRIC_TIMER_USES_CUDA, env_value)
+            monkeypatch.setenv(METRIC_TIMER_USES_GPU, env_value)
 
             # Test with timer="cpu" (should be overridden by env)
            tracer = Tracer("env_test", timer="cpu")