diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 04ea8efe9..7f1a65e1a 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -1,5 +1,5 @@ # Grouped Relative Policy Optimization (GRPO) -# >>> python -m apps.grpo.qwen3_1_7b --config apps/grpo/qwen3_1_7b.yaml +# >>> python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml # Global configuration group_size: 8 diff --git a/apps/toy_metrics/main.py b/apps/toy_metrics/main.py new file mode 100644 index 000000000..cd542df44 --- /dev/null +++ b/apps/toy_metrics/main.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio + +import logging +import time + +from forge.controller.actor import ForgeActor +from forge.controller.provisioner import shutdown +from forge.observability.metric_actors import setup_metric_logger +from forge.observability.metrics import record_metric, ReductionType + +from monarch.actor import current_rank, endpoint + +logging.basicConfig(level=logging.DEBUG) + + +class TrainActor(ForgeActor): + """Example training actor that records loss metrics.""" + + @endpoint + async def train_step(self, step: int): + rank = current_rank().rank + value = rank * 1000 + 100 * step + print(f"[TRAIN] Rank {rank}: Step {step}, loss={value}") + record_metric("train/loss", value) + + +class GeneratorActor(ForgeActor): + """Example generation actor that records token count metrics.""" + + @endpoint + async def generate_step(self, step: int, substep: int): + rank = current_rank().rank + value = rank * 1000 + step * 100 + substep * 10 + print(f"[GEN] Rank {rank}: Step {step}.{substep}, tokens={value}") + record_metric("generate/tokens", value, ReductionType.SUM) + + +# Main +async def main(): + """Example demonstrating distributed metric logging with different backends.""" + group = f"grpo_exp_{int(time.time())}" + + # Config format: {backend_name: backend_config_dict} + # Each backend can specify reduce_across_ranks to control distributed logging behavior + config = { + "console": {"reduce_across_ranks": True}, + "wandb": { + "project": "my_project", + "group": group, + "reduce_across_ranks": True, + # Only useful if NOT reduce_across_ranks. + "share_run_id": False, # Share run ID across ranks -- Not recommended. 
+ }, + } + + service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} + mlogger = await setup_metric_logger() + + # Spawn services first (triggers registrations via provisioner hook) + trainer = await TrainActor.options(**service_config).as_service() + generator = await GeneratorActor.options(**service_config).as_service() + + # Now init config on global (inits backends eagerly across fetchers) + await mlogger.init_backends.call_one(config) + + for i in range(3): + print(f"\n=== Global Step {i} ===") + await trainer.train_step.fanout(i) + for sub in range(3): + await generator.generate_step.fanout(i, sub) + await mlogger.flush.call_one(i) + + # shutdown + await mlogger.shutdown.call_one() + + await asyncio.gather( + trainer.shutdown(), + generator.shutdown(), + ) + + await shutdown() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/forge/controller/__init__.py b/src/forge/controller/__init__.py index 71d35c433..8f7c2f420 100644 --- a/src/forge/controller/__init__.py +++ b/src/forge/controller/__init__.py @@ -3,7 +3,6 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. - from .actor import ForgeActor from .proc_mesh import get_proc_mesh, stop_proc_mesh @@ -24,9 +23,4 @@ async def spawn_actors( return actors -__all__ = [ - "spawn_actors", - "stop_proc_mesh", - "get_proc_mesh", - "ForgeActor", -] +__all__ = ["spawn_actors", "stop_proc_mesh", "get_proc_mesh", "ForgeActor"] diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 26d51ea5c..c0670db1f 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -21,6 +21,8 @@ from monarch.tools.components import hyperactor from monarch.tools.config import Config +from forge.observability.metric_actors import setup_metric_logger + from forge.types import ProcessConfig logger = logging.getLogger(__name__) @@ -215,11 +217,19 @@ def bootstrap(gpu_ids: list[str]): self._server_names.append(server_name) self._proc_server_map[procs] = server_name + # Spawn local logging actor on each process and register with global logger + _ = await setup_metric_logger(procs) + return procs async def stop_proc_mesh(self, proc_mesh: ProcMesh): """Stops a proc mesh.""" async with self._lock: + # Deregister local logger from global logger + if hasattr(proc_mesh, "_local_fetcher"): + global_logger = await setup_metric_logger(proc_mesh) + await global_logger.deregister_fetcher.call_one(proc_mesh) + if hasattr(proc_mesh, "_gpu_ids"): gpu_manager = self._host_gpu_map[proc_mesh._host._host_id] gpu_manager.release_gpus(proc_mesh._gpu_ids) diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py new file mode 100644 index 000000000..4f630b8af --- /dev/null +++ b/src/forge/observability/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
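# Editor's note -- a rough, illustrative sketch (not part of this patch) of the lifecycle the
# provisioner hook above relies on. `procs` stands in for a ProcMesh created by the provisioner;
# the calls themselves are the APIs added in this diff.

from forge.observability.metric_actors import setup_metric_logger

async def logging_lifecycle_sketch(procs):
    # What get_proc_mesh() now does: spawning a mesh also spawns a per-mesh
    # LocalFetcherActor and registers it with the GlobalLoggingActor, keyed by the ProcMesh.
    global_logger = await setup_metric_logger(procs)
    assert await global_logger.has_fetcher.call_one(procs)
    assert hasattr(procs, "_local_fetcher")

    # What stop_proc_mesh() now does: the same key is deregistered before teardown.
    await global_logger.deregister_fetcher.call_one(procs)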
+ +from .metric_actors import GlobalLoggingActor, LocalFetcherActor, setup_metric_logger +from .metrics import ( + ConsoleBackend, + # Utility functions + get_actor_name_with_rank, + get_logger_backend_class, + # Backend classes + LoggerBackend, + MaxAccumulator, + MeanAccumulator, + # Accumulator classes + MetricAccumulator, + MetricCollector, + MinAccumulator, + record_metric, + reduce_metrics_states, + ReductionType, + StdAccumulator, + SumAccumulator, + WandbBackend, +) + +__all__ = [ + # Main API functions + "record_metric", + "reduce_metrics_states", + "get_actor_name_with_rank", + "get_logger_backend_class", + "setup_metric_logger", + # Enums + "ReductionType", + # Actor classes + "GlobalLoggingActor", + "LocalFetcherActor", + # Collector + "MetricCollector", + # Backend classes + "LoggerBackend", + "ConsoleBackend", + "WandbBackend", + # Accumulator classes + "MetricAccumulator", + "MeanAccumulator", + "SumAccumulator", + "MaxAccumulator", + "MinAccumulator", + "StdAccumulator", +] diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py new file mode 100644 index 000000000..53cca81a3 --- /dev/null +++ b/src/forge/observability/metric_actors.py @@ -0,0 +1,335 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +from typing import Any, Dict, Optional + +from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc + +from forge.observability.metrics import ( + get_logger_backend_class, + MetricCollector, + reduce_metrics_states, +) + +logger = logging.getLogger(__name__) + +_global_logger = None + + +async def setup_metric_logger( + proc_mesh: ProcMesh | None = None, +) -> "GlobalLoggingActor": + """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), + if not already initialized, registers it with the GlobalLoggingActor and returns the + GlobalLoggingActor instance. + + There are primarily two ways to use this function: + 1. In the main process, call `setup_metric_logger()` to get the global logger. + 2. In service processes, call `setup_metric_logger(proc_mesh)` to register the + local fetcher with the global logger. + + Args: + proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, + uses `monarch.actor.this_proc()`. + + Returns: + GlobalLoggingActor: The global logging controller. + + Raises: + ValueError: If the logging state is inconsistent, i.e. the fetcher is already + registered, but only in the process or the global logger. + + Example: + from forge.observability.metric_actors import setup_metric_logger + from forge.observability.metrics import record_metric + + # Main process setup + mlogger = await setup_metric_logger() + + # Initialize services... + policy = await Policy.as_service(...) + + # Initialize logging backends after all local fetchers are registered + # so each rank can have its own. + await mlogger.init_backends({ + "console": {"reduce_across_ranks": True}, + "wandb": {"project": "my_project", "reduce_across_ranks": False} + }) + + # Training loop + for step in range(max_steps): + record_metric("loss", 1.2, step, reduction_type=ReductionType.MEAN) + # ... training code with record_metric() calls ... 
+ await mlogger.flush(step) # Log metrics for this step + + # Shutdown + await mlogger.shutdown() + """ + # Get or create the singleton global logger + global _global_logger + if _global_logger is None: + _global_logger = await get_or_spawn_controller( + "global_logger", GlobalLoggingActor + ) + global_logger = _global_logger + + # Determine process context + proc = proc_mesh if proc_mesh is not None else this_proc() + + # Check current state for consistency + proc_has_local_fetcher = hasattr(proc, "_local_fetcher") + global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc) + + # Consistency check: both should be in sync + if proc_has_local_fetcher != global_logger_has_local_fetcher: + raise ValueError( + f"Inconsistent logging state for proc {proc}: " + f"proc has _local_fetcher={proc_has_local_fetcher}, " + f"but global_logger has registration={global_logger_has_local_fetcher}. " + f"This indicates a bug in logging setup/teardown. " + f"Both should be True (already setup) or both False (needs setup)." + ) + + # Setup local_fetcher_actor if needed + if not proc_has_local_fetcher: + local_fetcher_actor = await proc.spawn( + "local_fetcher_actor", LocalFetcherActor, global_logger + ) + await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) + proc._local_fetcher = local_fetcher_actor + + return global_logger + + +class LocalFetcherActor(Actor): + """Thin per-process actor used to trigger MetricCollector singleton + operations without direct access. It is what GlobalLoggingActor + uses to broadcast inits/flushes across ranks. + + GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector + """ + + def __init__(self, global_logger: Optional["GlobalLoggingActor"] = None) -> None: + self.global_logger = global_logger + _is_initialized = False + + @endpoint + async def flush( + self, step: int, return_state: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. + + Args: + step (int): train step used by backends to align all metrics on the same x-axis + return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. + If False, returns empty dict, else returns the state of all metrics collected. + Returns: + Dict[str, Dict[str, Any]]: Dict of {metric_key: metric_state}, + e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. + """ + collector = MetricCollector() + result = await collector.flush(step, return_state=return_state) + return result + + @endpoint + async def init_backends( + self, + metadata_per_primary_backend: Dict[str, Dict[str, Any]], + config: Dict[str, Any], + ): + """Init local (per-rank) logger backends and MetricCollector.""" + collector = MetricCollector() + await collector.init_backends(metadata_per_primary_backend, config) + + @endpoint + async def shutdown(self): + + collector = MetricCollector() + await collector.shutdown() + + +class GlobalLoggingActor(Actor): + """Coordinates metric logging across all ranks for every training step. + + Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), + for per-rank and/or global reduction logging modes. + + If a backend config has flag `reduce_across_ranks=False`, an instance of the backend + is initialized per-rank, otherwise it is done once globally. + + This GlobalLoggingActor should be spawned once in the controller. 
A LocalFetcherActor + is automatically spawned per-rank in `forge.controller.provisioner.py` and registered + with this actor. The LocalFetcherActor is responsible for instantiating + the per-rank MetricCollector. + + In summary, the flow is: + - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector + - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + """ + + def __init__(self): + self.fetchers: Dict[str, LocalFetcherActor] = {} + self.config: Dict[str, Any] | None = None + self.global_logger_backends: Dict[str, "LoggerBackend"] = {} + self.metadata_per_primary_backend: Dict[str, Dict[str, Any]] = {} + + @endpoint + async def init_backends(self, config: Dict[str, Any]): + """ + Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors + in all registered fetchers. + + A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source + for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. + + The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, + a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, + and will log independently, i.e. each rank will have its own run in wandb. + + Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states + and reduce them to a single value, which will be logged by the primary backend in this controller. + + Args: + config (Dict[str, Any]): Config for metric logging where keys are backend names, + e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} + """ + self.config = config + + for backend_name, backend_config in config.items(): + backend = get_logger_backend_class(backend_name)(backend_config) + await backend.init(role="global") + + # Extract metadata from primary logger to be shared with secondary loggers + # and store it + reduce_across_ranks = backend_config.get("reduce_across_ranks", True) + if not reduce_across_ranks: + primary_backend_metadata = ( + backend.get_metadata_for_secondary_ranks() or {} + ) + self.metadata_per_primary_backend[ + backend_name + ] = primary_backend_metadata + + # Store global logger backends + if reduce_across_ranks: + self.global_logger_backends[backend_name] = backend + + # Eager init collectors on all registered fetchers in parallel, passing primary states and config + if self.fetchers: + tasks = [ + fetcher.init_backends.call( + self.metadata_per_primary_backend, self.config + ) + for fetcher in self.fetchers.values() + ] + await asyncio.gather(*tasks, return_exceptions=True) + + @endpoint + async def register_fetcher(self, fetcher: LocalFetcherActor, name: str | ProcMesh): + """Registers a fetcher with the global actor. Each key represents a process mesh. + If there are 2 processes, each with 2 replicas with N gpus, we would + have 4 keys, i.e. 2 proces meshes, each with 2 replicas.""" + self.fetchers[name] = fetcher + + # Self-init for respawned actors + if self.config: + logger.debug(f"Initializing new LocalFetcherActor {name}") + await fetcher.init_backends.call( + self.metadata_per_primary_backend, self.config + ) + + @endpoint + async def deregister_fetcher(self, name: str | ProcMesh): + if name not in self.fetchers: + logger.warning( + f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister." 
+ f"Available fetchers: {self.fetchers.keys()}" + ) + return + del self.fetchers[name] + + @endpoint + async def flush(self, step: int): + """ + Triggers parallel flush/reset on all registered fetchers. Per-rank MetricCollectors + log to local backends and return states if needed for cross-rank reduction. + + Args: + step (int): Global step for logging. + """ + if not self.fetchers: + return + + config = self.config + # if reduce_across_ranks=True, we need to reduce the states from all ranks + # and log with the primary backend + requires_reduce = any( + backend_config.get("reduce_across_ranks", True) + for backend_config in config.values() + ) + + logger.debug(f"Global flush for step {step}: {len(self.fetchers)} fetchers") + + # Broadcast flush to all fetchers + results = await asyncio.gather( + *[ + f.flush.call(step, return_state=requires_reduce) + for f in self.fetchers.values() + ], + return_exceptions=True, + ) + + if requires_reduce: + # Handle exceptions and extract values from ValueMesh results + all_local_states = [] + for result in results: + if isinstance(result, Exception): + logger.warning(f"Flush failed on a fetcher: {result}") + continue + + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" + ) + + if not all_local_states: + logger.warning(f"No states to reduce for step {step}") + return + + # Reduce + reduced_metrics = reduce_metrics_states(all_local_states) + + # Log to each global logger_backend + for ( + logger_backend_name, + logger_backend, + ) in self.global_logger_backends.items(): + await logger_backend.log(reduced_metrics, step) + + @endpoint + def has_fetcher(self, name: str | ProcMesh) -> bool: + """Check if a fetcher is registered with the given name.""" + return name in self.fetchers + + @endpoint + def get_fetcher_count(self) -> int: + return len(self.fetchers) + + @endpoint + async def shutdown(self): + # Finish per-rank logger_backends via fetchers + if self.fetchers: + tasks = [fetcher.shutdown.call() for fetcher in self.fetchers.values()] + await asyncio.gather(*tasks, return_exceptions=True) + # Finish global logger_backends + for logger_backend_name, logger_backend in self.global_logger_backends.items(): + await logger_backend.finish() diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py new file mode 100644 index 000000000..2f2b70494 --- /dev/null +++ b/src/forge/observability/metrics.py @@ -0,0 +1,690 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
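# Editor's note -- a small, self-contained sketch (not part of this patch) of the cross-rank
# reduction that GlobalLoggingActor.flush() performs with the helpers defined below. The two
# dicts mimic states returned by two ranks' MetricCollector.flush(..., return_state=True).

from forge.observability.metrics import ReductionType, reduce_metrics_states

per_rank_states = [
    {"loss": {"reduction_type": ReductionType.MEAN, "sum": 14.0, "count": 5}},
    {"loss": {"reduction_type": ReductionType.MEAN, "sum": 16.0, "count": 10}},
]
# Merging raw states is exact: (14 + 16) / (5 + 10) == 2.0,
# whereas averaging the two per-rank means (2.8 and 1.6) would give 2.2.
assert reduce_metrics_states(per_rank_states) == {"loss": 2.0}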
+ +import logging +from abc import ABC, abstractmethod +from enum import Enum +from typing import Any, Dict, List, Optional + +from monarch.actor import context, current_rank + + +logger = logging.getLogger(__name__) + + +class ReductionType(Enum): + MEAN = "mean" + SUM = "sum" + MAX = "max" + MIN = "min" + STD = "std" + + @property + def accumulator_class(self): + mapping = { + ReductionType.MEAN: MeanAccumulator, + ReductionType.SUM: SumAccumulator, + ReductionType.MAX: MaxAccumulator, + ReductionType.MIN: MinAccumulator, + ReductionType.STD: StdAccumulator, + } + return mapping[self] + + +def get_actor_name_with_rank() -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. + """ + # Add more defensive checks + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + rank_name = "UnknownActor" # fallback + if len(parts) >= 2: + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Extract world ID and proc rank + world_id = world_part.split("[")[0] if "[" in world_part else world_part + + # Extract clean actor name (remove "Configured" suffix if present) + if "[" in actor_part: + actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" + if actor_name.endswith("Configured"): + actor_name = actor_name[:-10] # Remove "Configured" + else: + actor_name = actor_part + + # Use last 4 characters of world_id as replica identifier + # This is deterministic, readable, and works for any number of replicas + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + # Use current_rank().rank as the local rank within the replica + local_rank = rank.rank + + rank_name = f"{actor_name}_{replica_id}_r{local_rank}" + + return rank_name + + +def record_metric( + key: str, value: Any, reduction: ReductionType = ReductionType.MEAN +) -> None: + """ + Records a metric value for later reduction and logging. + + Relies on a per-rank MetricCollector singleton for ease of use, i.e. + call `record_metric` anywhere in the code without moving the + collector from function to function. + + The collector methods are triggered per-rank by a + `forge.observability.metric_actors.LocalFetcherActor`, instantiated + during actor initialization. + + Records are flushed when `forge.observability.metric_actors.GlobalLoggingActor.flush()` + is called, typically triggered by the training loop at regular intervals. + """ + collector = MetricCollector() + collector.push(key, value, reduction) + + +def reduce_metrics_states(states: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Any]: + """Reduce metric accumulators states to a single value per metric. + + Can be used when reducing metrics across ranks or services, as merging + states is more precise than merging locally reduced metrics. + + Args: + states (List[Dict[str, Dict[str, Any]]]): List of state of one or more metrics, + normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. 
+ + Returns: + Dict[str, Any]: Dictionary with format {metric_key: reduced_value} + + Example: + states = [ + {"loss": {"count": 5, "sum": 14, "reduction_type": ReductionType.MEAN}}, + {"loss": {"count": 10, "sum": 16, "reduction_type": ReductionType.MEAN}}, + ] + reduce_metrics_states(states) + >>> {"loss": 2.0} + + Raises: + ValueError: on mismatched reduction types for the same metric key. + """ + if not states: + return {} + + # Collect unique keys across all + all_keys = set(k for state in states for k in state) + + reduced_metrics = {} + for key in all_keys: + metric_states = [state.get(key) for state in states if key in state] + if not metric_states: + continue + + first_reduction_type = metric_states[0]["reduction_type"] + + # Check consistency + for state in metric_states: + if state["reduction_type"] != first_reduction_type: + raise ValueError( + f"Mismatched reduction types for key '{key}': {first_reduction_type} vs {state['reduction_type']}" + ) + + metric_accumulator = ReductionType(first_reduction_type).accumulator_class + reduced_value = metric_accumulator.get_reduced_value_from_states(metric_states) + reduced_metrics[key] = reduced_value + + return reduced_metrics + + +################ +# Accumulators # +################ + + +class MetricAccumulator(ABC): + """Every metric maps to a MetricAccumulator, which accumulates values and optionally reduces them.""" + + def __init__(self, reduction: ReductionType): + self.reduction_type = reduction + + @abstractmethod + def append(self, value: Any) -> None: + """Updates accumulator with new value (e.g., adds to sum and count for MEAN).""" + pass + + @abstractmethod + def get_value(self) -> Any: + """Returns locally reduced value (e.g., sum/count for MEAN).""" + pass + + @abstractmethod + def get_state(self) -> Dict[str, Any]: + """Returns serializable state for cross-rank merge (e.g., {'sum': 10.0, 'count': 5}).""" + pass + + @classmethod + @abstractmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> Any: + """Merges states from multiple ranks into single reduced value (e.g., total_sum/total_count for MEAN).""" + pass + + @abstractmethod + def reset(self) -> None: + """Clears for next accumulation cycle (e.g., sum=0, count=0 for MEAN).""" + pass + + +class MeanAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.count += 1 + + def get_value(self) -> float: + return self.sum / self.count if self.count > 0 else 0.0 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "count": self.count, + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + total_sum = sum(s["sum"] for s in states) + total_count = sum(s["count"] for s in states) + return total_sum / total_count if total_count > 0 else 0.0 + + def reset(self) -> None: + self.sum = 0.0 + self.count = 0 + + +class SumAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.total = 0.0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.total += v + + def get_value(self) -> float: + return self.total + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "total": 
self.total} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return sum(s["total"] for s in states) + + def reset(self) -> None: + self.total = 0.0 + + +class MaxAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.max_val = float("-inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.max_val = max(self.max_val, v) + + def get_value(self) -> float: + return self.max_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "max_val": self.max_val} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return max(s["max_val"] for s in states) + + def reset(self) -> None: + self.max_val = float("-inf") + + +class MinAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.min_val = float("inf") + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.min_val = min(self.min_val, v) + + def get_value(self) -> float: + return self.min_val + + def get_state(self) -> Dict[str, Any]: + return {"reduction_type": self.reduction_type.value, "min_val": self.min_val} + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + return min(s["min_val"] for s in states) + + def reset(self) -> None: + self.min_val = float("inf") + + +class StdAccumulator(MetricAccumulator): + def __init__(self, reduction: ReductionType): + super().__init__(reduction) + self.sum = 0.0 + self.sum_sq = 0.0 + self.count = 0 + + def append(self, value: Any) -> None: + v = float(value.item() if hasattr(value, "item") else value) + self.sum += v + self.sum_sq += v * v + self.count += 1 + + def get_value(self) -> float: + if self.count == 0: + return 0.0 + if self.count == 1: + return 0.0 + mean = self.sum / self.count + variance = (self.sum_sq / self.count) - (mean * mean) + return max(0.0, variance) ** 0.5 + + def get_state(self) -> Dict[str, Any]: + return { + "reduction_type": self.reduction_type.value, + "sum": self.sum, + "sum_sq": self.sum_sq, + "count": self.count, + } + + @classmethod + def get_reduced_value_from_states(cls, states: List[Dict[str, Any]]) -> float: + total_sum = sum(s["sum"] for s in states) + total_sum_sq = sum(s["sum_sq"] for s in states) + total_count = sum(s["count"] for s in states) + if total_count == 0: + return 0.0 + if total_count == 1: + return 0.0 + mean = total_sum / total_count + variance = (total_sum_sq / total_count) - (mean * mean) + return max(0.0, variance) ** 0.5 + + def reset(self) -> None: + self.sum = 0.0 + self.sum_sq = 0.0 + self.count = 0 + + +############# +# Collector # +############# + + +class MetricCollector: + """Per-rank singleton for accumulating, retrieving and flushing metrics to backends. + + A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, + the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally, + in the GlobalLoggingActor. + + - Ensures one instance per process; actors call record_metric() which delegates here. + - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; + - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also + return non-reduced states for global aggregation. 
This can be different for each backend. + - Resets accumulators post-flush to avoid leaks across train steps; + """ + + _instances: Dict[int, "MetricCollector"] = {} + + def __new__(cls): + """Singleton per-rank, ensures one instance per process.""" + rank = current_rank().rank + + if rank not in cls._instances: + inst = super().__new__(cls) + cls._instances[rank] = inst + inst._singleton_rank = rank + else: + inst = cls._instances[rank] + if inst._singleton_rank != rank: + raise ValueError( + f"Singleton expected rank {inst._singleton_rank}, but saw {rank}" + ) + return inst + + def __init__(self): + if hasattr(self, "_is_initialized"): + return + + self.accumulators: Dict[str, MetricAccumulator] = {} + self.rank = current_rank().rank + self.logger_backends: List[LoggerBackend] = [] + self._is_initialized = False + + async def init_backends( + self, + metadata_per_primary_backend: Optional[Dict[str, Dict[str, Any]]], + config: Dict[str, Any], + ) -> None: + """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, + the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated + once globally. + + Args: + metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary + logger backend, e.g., {"wandb": {"run_id": "abc123"}}. + config (Dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + """ + if self._is_initialized: + logger.debug(f"Rank {self.rank}: MetricCollector already initialized") + return + + # instantiate local backends if any + for backend_name, backend_config in config.items(): + if backend_config.get("reduce_across_ranks", True): + continue # Skip local backend instantiation and use global instead + primary_state = metadata_per_primary_backend.get(backend_name, {}) + logger_backend = get_logger_backend_class(backend_name)(backend_config) + await logger_backend.init( + role="local", primary_logger_metadata=primary_state + ) + self.logger_backends.append(logger_backend) + + self._is_initialized = True + + def push( + self, key: str, value: Any, reduction: ReductionType = ReductionType.MEAN + ) -> None: + if not self._is_initialized: + raise ValueError("Collector not initialized—call init first") + + if key not in self.accumulators: + self.accumulators[key] = reduction.accumulator_class(reduction) + + self.accumulators[key].append(value) + + async def flush( + self, step: int, return_state: bool = False + ) -> Dict[str, Dict[str, Any]]: + """Log to local logger backends (if any), reset accumulators and return metric states dict if return_state=True. + + Args: + step (int): Step used by backends to align metrics on the same x-axis + return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. + If False, returns empty dict, else returns the state of all metrics collected. + Returns: + Dict[str, Dict[str, Dict[str, Any]]]: Dict of {metric_key: metric_state}, + e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. + """ + if not self._is_initialized: + logger.debug( + f"Collector not yet initialized for {get_actor_name_with_rank()}. Call init_backends first." 
+ ) + return {} + + if not self.accumulators: + logger.debug( + f"Collector rank {get_actor_name_with_rank()}: No metrics to flush for step {step}" + ) + return {} + + # Snapshot states and reset immediately + states = {} + for key, acc in self.accumulators.items(): + states[key] = acc.get_state() + acc.reset() + + # Reduce metrics from states for logging if any per-rank backend + if self.logger_backends: + metrics = {} + for key, state in states.items(): + acc_class = ReductionType(state["reduction_type"]).accumulator_class + metrics[key] = acc_class.get_reduced_value_from_states([state]) + + # Log to local logger_backends + for logger_backend in self.logger_backends: + await logger_backend.log(metrics, step) + + return states if return_state else {} + + async def shutdown(self): + """Shutdown logger_backends if initialized.""" + if not self._is_initialized: + logger.debug( + f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" + ) + return + + for logger_backend in self.logger_backends: + await logger_backend.finish() + + +########### +# Backends # +########### + + +class LoggerBackend(ABC): + """Abstract logger_backend for metric logging, e.g. wandb, jsonl, etc.""" + + def __init__(self, logger_backend_config: Dict[str, Any]): + self.logger_backend_config = logger_backend_config + + @abstractmethod + async def init( + self, + role: str, + primary_logger_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initializes backend, e.g. wandb.run.init(). + + Args: + role (str): "global" (controller/primary) or "local" (per-rank/secondary). + Can be used to behave differently for primary vs secondary roles. + primary_logger_metadata (Optional[Dict[str, Any]]): From global backend for + backend that required shared info, e.g. {"shared_run_id": "abc123"}. + + Raises: ValueError if missing metadata for shared local init. + """ + if primary_logger_metadata is None: + primary_logger_metadata = {} + pass + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + pass + + async def finish(self) -> None: + pass + + def get_metadata_for_secondary_ranks(self) -> Optional[Dict[str, Any]]: + """Return sharable state after primary init (e.g., for shared modes). Called only on globals.""" + return None + + +class ConsoleBackend(LoggerBackend): + """Simple console logging of metrics.""" + + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + + async def init( + self, + role: str, + primary_logger_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + self.prefix = ( + get_actor_name_with_rank() + if self.logger_backend_config.get("reduce_across_ranks", True) + else "GLOBAL" + ) + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + logger.info(f"=== [{self.prefix}] - METRICS STEP {step} ===") + for key, value in metrics.items(): + logger.info(f" {key}: {value}") + logger.info("==============================\n") + + async def finish(self) -> None: + pass + + +class WandbBackend(LoggerBackend): + """ + Weights & Biases logging backend for distributed training. 
+ + Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: + Track a single process: reduce_across_ranks=True + Track each process separately: reduce_across_ranks=False, share_run_id=False + Track all processes to a single run: reduce_across_ranks=False, share_run_id=True + + Configuration: + reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). + If False, enables per-rank logging; then use share_run_id to pick mode. + share_run_id (bool, default False): Only used if reduce_across_ranks=False. + True -> shared run across ranks; False -> separate runs per rank. + project (str): WandB project name + group (str, optional): WandB group name for organizing runs. Defaults to "experiment_group" + """ + + def __init__(self, logger_backend_config: Dict[str, Any]): + super().__init__(logger_backend_config) + self.project = logger_backend_config["project"] + self.group = logger_backend_config.get("group", "experiment_group") + self.name = None + self.run = None + self.reduce_across_ranks = logger_backend_config.get( + "reduce_across_ranks", True + ) + self.share_run_id = logger_backend_config.get("share_run_id", False) + + async def init( + self, + role: str, + primary_logger_metadata: Optional[Dict[str, Any]] = None, + ) -> None: + + if primary_logger_metadata is None: + primary_logger_metadata = {} + + if role not in ["global", "local"]: + raise ValueError( + f"Invalid role {role} for WandbBackend init. Must be 'global' or 'local'." + ) + + self.name = ( + get_actor_name_with_rank() if role == "local" else "global_controller" + ) + + # Default global mode: only inits on controller + if self.reduce_across_ranks: + if role != "global": + logger.debug( + f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." 
+ ) + return + await self._init_global() + + # Per-rank modes based on share_run_id bool + elif role == "global" and self.share_run_id: + await self._init_shared_global() + + elif role == "local": + if self.share_run_id: + await self._init_shared_local(primary_logger_metadata) + else: + await self._init_per_rank() + + async def _init_global(self): + import wandb + + self.run = wandb.init(project=self.project, group=self.group) + + async def _init_per_rank(self): + import wandb + + self.run = wandb.init(project=self.project, group=self.group, name=self.name) + + async def _init_shared_global(self): + import wandb + + settings = wandb.Settings( + mode="shared", x_primary=True, x_label="controller_primary" + ) + self.run = wandb.init(project=self.project, group=self.group, settings=settings) + + async def _init_shared_local(self, primary_metadata: Dict[str, Any]): + import wandb + + shared_id = primary_metadata.get("shared_run_id") + if shared_id is None: + raise ValueError( + f"Shared ID required but not provided for {self.name} backend init" + ) + settings = wandb.Settings(mode="shared", x_primary=False, x_label=self.name) + self.run = wandb.init( + id=shared_id, + project=self.project, + group=self.group, + settings=settings, + ) + + async def log(self, metrics: Dict[str, Any], step: int) -> None: + if self.run: + log_data = {**metrics, "global_step": step} + self.run.log(log_data) + logger.info(f"WandbBackend: Logged {len(metrics)} metrics at step {step}") + else: + logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") + + def get_metadata_for_secondary_ranks(self) -> Dict[str, Any]: + if self.run and not self.reduce_across_ranks and self.share_run_id: + return {"shared_run_id": self.run.id} + return {} + + async def finish(self) -> None: + if self.run: + self.run.finish() + logger.info(f"WandbBackend {self.name}: Finished run") + + +def get_logger_backend_class(cls_name: str) -> type[LoggerBackend]: + """Simple mapping between logger_backend type and its class + + Factory for backend classes from config; returns uninitialized class for role-based init. + """ + if cls_name == "console": + return ConsoleBackend + elif cls_name == "wandb": + return WandbBackend + else: + raise ValueError(f"Unknown logger backend type: {cls_name}")
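# Editor's note -- an illustrative config sketch (not part of this patch) mapping the three
# WandbBackend modes documented above onto the dict accepted by GlobalLoggingActor.init_backends().
# The project/group names are placeholders.

single_global_run = {
    "wandb": {"project": "my_project", "reduce_across_ranks": True},
}  # one run, logged from the controller with cross-rank reduced metrics

one_run_per_rank = {
    "wandb": {
        "project": "my_project",
        "group": "exp_group",
        "reduce_across_ranks": False,
        "share_run_id": False,
    },
}  # each rank logs its own run, grouped under "exp_group"

all_ranks_one_run = {
    "wandb": {"project": "my_project", "reduce_across_ranks": False, "share_run_id": True},
}  # every rank logs into a single shared run (wandb "shared" mode)

# e.g. in the controller, as in apps/toy_metrics/main.py:
#     await mlogger.init_backends.call_one(one_run_per_rank)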