diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 14e4871cf..a9acd268d 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -16,11 +16,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 593a2e1fb..14dcf7629 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -19,11 +19,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 9b2f70edd..f5367ce4b 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -12,11 +12,12 @@ off_by_n: 1 # Off by one by default # Observability configuration metric_logging: wandb: - project: "grpo-training" - group: "grpo_exp_${oc.env:USER}" - reduce_across_ranks: True + project: grpo-training + group: grpo_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False console: - reduce_across_ranks: True + logging_mode: global_reduce # Dataset configuration dataset: diff --git a/apps/sft_v2/llama3_8b.yaml b/apps/sft_v2/llama3_8b.yaml index 273d2d592..4a8712efa 100644 --- a/apps/sft_v2/llama3_8b.yaml +++ b/apps/sft_v2/llama3_8b.yaml @@ -55,6 +55,13 @@ activation_checkpoint: mode: selective selective_ac_option: op +metric_logging: + wandb: + project: sft-training + group: sft_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + per_rank_share_run: False + # profiling: # enable_profiling: false diff --git a/apps/sft_v2/main.py b/apps/sft_v2/main.py index 61b27baa3..132340359 100644 --- a/apps/sft_v2/main.py +++ b/apps/sft_v2/main.py @@ -16,6 +16,7 @@ import math import os import sys +import warnings from functools import partial from typing import Any @@ -28,6 +29,7 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer +from forge.observability import get_or_create_metric_logger, record_metric, Reduce from monarch.actor import current_rank, current_size, endpoint from omegaconf import DictConfig, OmegaConf @@ -109,9 +111,20 @@ def _init_dist(self): os.environ.update(env) logger.info("env: {}".format(env)) + async def setup_metric_logger(self): + """Initialization happens in the main process. 
Here we just retrieve it""" + mlogger = await get_or_create_metric_logger() + return mlogger + + def record_batch_metrics(self, data_metrics: list): + """Record dataset metrics using the observability system.""" + for metric in data_metrics: + record_metric(metric.key, metric.value, metric.reduction) + @endpoint async def setup(self): self.train_dataloader = self.setup_data() + self.mlogger = await self.setup_metric_logger() # self.train_dataloader = self.setup_data( # self.train_config.train_dataset_config, # self.train_config.train_dataloader_config, @@ -235,6 +248,7 @@ def train_step(self, batch) -> None: labels = batch.pop("labels") loss = self.forward_backward(batch, labels) + record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) @@ -251,14 +265,27 @@ async def train(self) -> None: while self.current_step < self.num_training_steps: batch = next(dataloader) + + # Pop and record metrics from batch before moving to device + self.record_batch_metrics(batch.pop("metrics", [])) + record_metric( + "ForgeSFTRecipe/train_step/step", self.current_step, Reduce.MEAN + ) + # Move tensors to the appropriate device for k, v in batch.items(): if isinstance(v, torch.Tensor): batch[k] = v.to("cuda") # TODO: hardcoded for now + self.train_step(batch) # self.profiler.step() self.current_step += 1 + # Flush metrics + if self._rank == 0: + logger.debug(f"Flushing metrics at step {self.current_step}") + await self.mlogger.flush.call_one(global_step=self.current_step) + self.checkpointer.save( curr_step=self.current_step, last_step=self.current_step == self.num_training_steps, @@ -270,16 +297,35 @@ async def train(self) -> None: async def cleanup(self) -> None: if self.checkpointer: self.checkpointer.close() - if self.metric_logger: - self.metric_logger.close() + if hasattr(self, "mlogger") and self.mlogger: + await self.mlogger.shutdown.call_one() def __repr__(self) -> str: return "Trainer" async def run(cfg: DictConfig) -> None: - logging.info("Spawing recipe...") + + # TODO (allenwang28) Required for metric logging to work. Should be removed when V1 becomes default + MONARCH_HOSTMESH_V1 = os.getenv("MONARCH_HOSTMESH_V1") + if MONARCH_HOSTMESH_V1 != "1": + warnings.warn( + "MONARCH_HOSTMESH_V1 is set to {MONARCH_HOSTMESH_V1}. Setting it to '1' for SFT v2 to work properly. ", + UserWarning, + stacklevel=2, + ) + os.environ["MONARCH_HOSTMESH_V1"] = "1" + + logging.info("Spawning recipe...") process_cfg = cfg.pop("processes") + + # Initialize metric logger in main process + metric_logging_cfg = cfg.get( + "metric_logging", {"console": {"logging_mode": "global_reduce"}} + ) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg) logging.info("Created recipe, running setup.") @@ -290,6 +336,9 @@ async def run(cfg: DictConfig) -> None: logging.info("Done training. 
Clean up") await recipe.cleanup.call() + + # Shutdown metric logger + await mlogger.shutdown.call_one() await recipe.mesh.stop() logging.info("All done!") diff --git a/src/forge/data/__init__.py b/src/forge/data/__init__.py index 4347199b9..74ba663e0 100644 --- a/src/forge/data/__init__.py +++ b/src/forge/data/__init__.py @@ -4,6 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from .collate import collate_packed +from .metric_transform import DefaultDatasetMetricTransform, MetricTransform from .utils import CROSS_ENTROPY_IGNORE_IDX -__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"] +__all__ = [ + "collate_packed", + "CROSS_ENTROPY_IGNORE_IDX", + "MetricTransform", + "DefaultDatasetMetricTransform", +] diff --git a/src/forge/data/dataset_metrics/__init__.py b/src/forge/data/dataset_metrics/__init__.py deleted file mode 100644 index 3a218e282..000000000 --- a/src/forge/data/dataset_metrics/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from .metric_agg_handlers import ( - AggregationHandler, - CategoricalCountAggHandler, - MaxAggHandler, - MeanAggHandler, - MetricState, - MinAggHandler, - StatsAggHandler, - SumAggHandler, -) -from .metric_aggregator import MetricsAggregator -from .metric_transform import ( - AggregationType, - DefaultTrainingMetricTransform, - Metric, - MetricTransform, -) - -__all__ = [ - "AggregationType", - "AggregationHandler", - "CategoricalCountAggHandler", - "DefaultTrainingMetricTransform", - "StatsAggHandler", - "MaxAggHandler", - "MeanAggHandler", - "Metric", - "MetricState", - "MetricsAggregator", - "MetricTransform", - "MinAggHandler", - "SumAggHandler", -] diff --git a/src/forge/data/dataset_metrics/metric_agg_handlers.py b/src/forge/data/dataset_metrics/metric_agg_handlers.py deleted file mode 100644 index bb3978a6b..000000000 --- a/src/forge/data/dataset_metrics/metric_agg_handlers.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from abc import ABC, abstractmethod -from collections import Counter, deque -from dataclasses import dataclass, field -from typing import Any - -import torch - -from .metric_transform import AggregationType, Metric - -logger = logging.getLogger(__name__) - - -@dataclass -class MetricState: - """Mutable state object representing the state of a (source, metric_name) on a single rank. - - Attributes: - source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation. - metric_name (str): Name of the metric. - value (float): Current aggregated value, whose meaning depends on the aggregation type - (e.g., running sum, current max). - agg_type (AggregationType): Aggregation type. - metadata (dict[str, Any]): Additional state like count, list of values, etc. - """ - - source: str - metric_name: str - value: float - agg_type: AggregationType - metadata: dict[str, Any] = field(default_factory=dict) - - -class AggregationHandler(ABC): - """Base class for handling metric aggregation. - - Each handler implements a specific aggregation strategy (SUM, MEAN, STATS, etc.) 
- and manages the complete lifecycle: initialization, updates, local finalization, - and distributed reduction. Handlers also handle serialization for checkpointing. - - The handler architecture allows pluggable aggregation strategies while maintaining - consistent interfaces for the MetricsAggregator. - """ - - @abstractmethod - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - """Create a new MetricState for a (source, metric_name) pair. - - Args: - source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation. - metric_name (str): Name of the metric. - agg_type (AggregationType): Aggregation type. - - Returns: - MetricState: New MetricState for this (source, metric_name) pair. - """ - pass - - @abstractmethod - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - """Update cumulative MetricState with new metric info. - - Args: - local_agg_metric (MetricState): State of the aggregation for this metric in the local rank. - metric (Metric): Input metric info. - """ - pass - - @abstractmethod - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - """ - Computes the final value from the locally aggregated state. For example, for mean - it would mean to divide the tracked sum by the tracked count. - - This method may expand a single metric into multiple, for instance, - a list of numbers into mean, min, max, and percentiles. - - Args: - local_agg_metric (MetricState): The locally aggregated metric state to finalize. - - Returns: - list[MetricState]: List of finalized metric states. - """ - pass - - @abstractmethod - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - """ - Merge MetricStates from all ranks into final result. For example, for 'sum', it would mean to - sum the values from all ranks. - - Args: - local_agg_metrics (list[MetricState]): list of MetricStates from all ranks for a specific - (source, metric_name) tuple after computing finalize_local_agg. - - Returns: - MetricState: Final result for this (source, metric_name) pair. - """ - pass - - def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert handler-specific metadata to serializable format. Override this when using - non-serializable types like deque or Counter. For example, convert deque to list, Counter to dict. - - Args: - metadata (dict[str, Any]): AggHandler-specific metadata. - - Returns: - dict[str, Any]: Serializable metadata. - """ - return metadata.copy() - - def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Restore handler-specific metadata from serialized format. Override this to reverse the - serialize_metadata transformation. For example, convert list back to deque, dict back to Counter. - - Args: - metadata (dict[str, Any]): AggHandler-specific metadata. - - Returns: - dict[str, Any]: Deserialized metadata. - """ - return metadata.copy() - - -class SumAggHandler(AggregationHandler): - """AggHandler for SUM aggregation. 
Initializes with 0.0 and accumulates metric values.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - if not isinstance(metric.value, (int, float)): - raise ValueError( - f"SumAggHandler expects numeric values, got {type(metric.value)}" - ) - local_agg_metric.value += metric.value - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - if not local_agg_metrics: - raise ValueError("Cannot aggregate empty list of metrics") - - total = sum(metric.value for metric in local_agg_metrics) - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=total, - agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy(), - ) - - -class MaxAggHandler(AggregationHandler): - """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=float("-inf"), - agg_type=agg_type, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - if not isinstance(metric.value, (int, float)): - raise ValueError( - f"MaxAggHandler expects numeric values, got {type(metric.value)}" - ) - local_agg_metric.value = max(local_agg_metric.value, metric.value) - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - max_value = max(r.value for r in local_agg_metrics) - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=max_value, - agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy(), - ) - - -class MinAggHandler(AggregationHandler): - """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=float("inf"), - agg_type=agg_type, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - if not isinstance(metric.value, (int, float)): - raise ValueError( - f"MinAggHandler expects numeric values, got {type(metric.value)}" - ) - local_agg_metric.value = min(local_agg_metric.value, metric.value) - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - min_value = min(r.value for r in local_agg_metrics) - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=min_value, - agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy(), - ) - - -class MeanAggHandler(AggregationHandler): - """AggHandler for MEAN aggregation. 
Maintains running sum and count to compute average.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - metadata={"sum": 0.0, "count": 0}, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - local_agg_metric.metadata["sum"] += metric.value - local_agg_metric.metadata["count"] += 1 - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - count = local_agg_metric.metadata["count"] - local_agg_metric.value = ( - local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 - ) - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) - total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) - - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=total_sum / total_count if total_count > 0 else 0.0, - agg_type=local_agg_metrics[0].agg_type, - metadata={"sum": total_sum, "count": total_count}, - ) - - -class StatsAggHandler(AggregationHandler): - """AggHandler for STATS aggregation. Maintains a sliding window of values - and expands into multiple statistical metrics (mean, min, max, percentiles, std). - - Note: Percentiles and standard deviation are approximated in distributed settings by averaging local - percentiles and standard deviations across ranks. This is mathematically imprecise but provides a - reasonable approximation for monitoring purposes. - - Args: - window_size (int): Maximum number of recent values to retain for statistics. - - Raises: - ValueError: If window_size is not positive. - """ - - def __init__(self, window_size: int = 1000): - if window_size <= 0: - raise ValueError(f"window_size must be positive, got {window_size}") - self.window_size = window_size - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - metadata={"values": deque(maxlen=self.window_size)}, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - local_agg_metric.metadata["values"].append(metric.value) - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - values = list(local_agg_metric.metadata["values"]) - if not values: - return [] - - values_tensor = torch.tensor(values, dtype=torch.float64) - n = len(values_tensor) - - # Compute stats from the tensor - sum_val = torch.sum(values_tensor).item() - mean_val = sum_val / n - min_val = torch.min(values_tensor).item() - max_val = torch.max(values_tensor).item() - - # Compute percentiles - percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) - p05_val, p50_val, p95_val = torch.quantile( - values_tensor, percentile_definitions - ).tolist() - - # Return multiple MetricStates with proper agg_types for distributed reduction - # NOTE: Percentiles use MEAN aggregation which approximates global percentiles - # by averaging local percentiles. 
- metrics = [ - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_mean", - value=mean_val, - agg_type=AggregationType.MEAN, - metadata={"sum": sum_val, "count": n}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_min", - value=min_val, - agg_type=AggregationType.MIN, - metadata={}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_max", - value=max_val, - agg_type=AggregationType.MAX, - metadata={}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_p05", - value=p05_val, - agg_type=AggregationType.MEAN, - metadata={"sum": p05_val, "count": 1}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_p50", - value=p50_val, - agg_type=AggregationType.MEAN, - metadata={"sum": p50_val, "count": 1}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_p95", - value=p95_val, - agg_type=AggregationType.MEAN, - metadata={"sum": p95_val, "count": 1}, - ), - ] - - # Standard deviation is only well-defined for n > 1 - if n > 1: - std_val = torch.std(values_tensor).item() - metrics.append( - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_std", - value=std_val, - agg_type=AggregationType.MEAN, - metadata={"sum": std_val, "count": 1}, - ) - ) - return metrics - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - raise NotImplementedError( - "Metrics with AggregationType.STATS were converted to other " - "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." - ) - - def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert deque to list for serialization.""" - serialized = metadata.copy() - if "values" in serialized: - serialized["values"] = list(serialized["values"]) - return serialized - - def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert list back to deque.""" - deserialized = metadata.copy() - if "values" in deserialized: - deserialized["values"] = deque( - deserialized["values"], maxlen=self.window_size - ) - return deserialized - - -class CategoricalCountAggHandler(AggregationHandler): - """AggHandler for CATEGORICAL_COUNT aggregation. 
Counts occurrences of categorical values - and expands into individual count metrics for each category.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - metadata={"counts": Counter()}, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - local_agg_metric.metadata["counts"][metric.value] += 1 - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - # Expand categorical counts into individual metrics - results = [] - for category, count in local_agg_metric.metadata["counts"].items(): - results.append( - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_count_{category}", - value=count, - agg_type=AggregationType.SUM, - ) - ) - return results - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - raise NotImplementedError( - "Metrics with AggregationType.CATEGORICAL_COUNT were converted to other " - "AggregationType.SUM for distributed reduction. finalize_dist_agg should not be called." - ) - - def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert Counter to dict for serialization.""" - serialized = metadata.copy() - if "counts" in serialized: - serialized["counts"] = dict(serialized["counts"]) - return serialized - - def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert dict back to Counter.""" - deserialized = metadata.copy() - if "counts" in deserialized: - deserialized["counts"] = Counter(deserialized["counts"]) - return deserialized diff --git a/src/forge/data/dataset_metrics/metric_aggregator.py b/src/forge/data/dataset_metrics/metric_aggregator.py deleted file mode 100644 index 40d8075ce..000000000 --- a/src/forge/data/dataset_metrics/metric_aggregator.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import ast -import logging -from collections import defaultdict -from typing import Any, Union - -import torch.distributed as dist - -from .metric_agg_handlers import ( - AggregationHandler, - CategoricalCountAggHandler, - MaxAggHandler, - MeanAggHandler, - MetricState, - MinAggHandler, - StatsAggHandler, - SumAggHandler, -) -from .metric_transform import AggregationType, Metric - -logger = logging.getLogger(__name__) - - -class MetricsAggregator: - """Aggregates metrics across datasets and distributed ranks using pluggable handlers. - - This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.) - has a corresponding AggregationHandler. It maintains a single state object for each - (source, metric_name) pair. - - Internal State Visualization: - { - ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...), - ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}), - ("slim_orca", "seq_len"): MetricState(agg_type=STATS, metadata={'values': deque([...])}), - } - - When preparing metrics for logging, the aggregator follows a two-phase process: - 1. Local Aggregation: Each rank aggregates its metrics independently - 2. 
Distributed Reduction: If in distributed mode, results are combined across ranks - - The aggregator's state is checkpointable, allowing training resumption. - - Args: - dist_window_size (int): Window size for StatsAggHandler tracking. - - Example: - >>> from forge.data.metrics import MetricsAggregator, Metric, AggregationType - >>> - >>> aggregator = MetricsAggregator() - >>> - >>> # Sample metrics from different batches - >>> batch1_metrics = [ - ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), - ... ] - >>> - >>> batch2_metrics = [ - ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), - ... ] - >>> - >>> # Update with metrics - >>> aggregator.update(batch1_metrics) - >>> aggregator.update(batch2_metrics) - >>> - >>> # Get final results - >>> results = aggregator.get_metrics_for_logging(prefix="train") - >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} - - Raises: - ValueError: If dist_window_size is not positive. - """ - - def __init__(self, dist_window_size: int = 1000): - if dist_window_size <= 0: - raise ValueError( - f"dist_window_size must be positive, got {dist_window_size}" - ) - - # Storage: {(source, metric_name): MetricState} - O(unique metrics) not O(samples) - self._metric_states: dict[tuple[str, str], MetricState] = {} - self._dist_window_size = dist_window_size - - # Track aggregation types for validation - prevents same metric name with different agg types - self._metric_agg_types: dict[tuple[str, str], AggregationType] = {} - - # Create handler registry - all handlers initialized upfront - self._handlers: dict[AggregationType, AggregationHandler] = { - AggregationType.SUM: SumAggHandler(), - AggregationType.MAX: MaxAggHandler(), - AggregationType.MIN: MinAggHandler(), - AggregationType.MEAN: MeanAggHandler(), - AggregationType.STATS: StatsAggHandler(dist_window_size), - AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), - } - - def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None: - """Validate that metric name uses consistent aggregation type.""" - metric_key = (metric.source, metric.metric_name) - metric_name = metric.metric_name - - if metric_key in self._metric_agg_types: - existing_agg_type = self._metric_agg_types[metric_key] - if existing_agg_type != metric.agg_type: - raise ValueError( - f"Metric '{metric_name}' in dataset '{metric.source}' " - f"is already registered with aggregation type {existing_agg_type.value}, " - f"but a handler or user code tried to use it with type {metric.agg_type.value}. " - f"Use different metric names for different aggregation types." - ) - else: - # Track this metric's aggregation type - self._metric_agg_types[metric_key] = metric.agg_type - - def register_handler( - self, agg_type: AggregationType, handler: AggregationHandler - ) -> None: - """Register custom aggregation handler for specified type. - - Args: - agg_type (AggregationType): The aggregation type to handle - handler (AggregationHandler): Handler instance implementing the AggregationHandler interface - """ - # Warn if replacing a handler that's already in use - if agg_type in self._handlers and any( - state.agg_type == agg_type for state in self._metric_states.values() - ): - logger.warning( - f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. " - f"This may affect existing metric behavior." 
- ) - - self._handlers[agg_type] = handler - - def update(self, metrics: list[Metric]) -> None: - """Update (source, metric_name) metric state with new values. - - Args: - metrics (list[Metric]): List of metrics to update the state with - - Raises: - ValueError: If no handler is registered for a metric's aggregation type, - or if metric name conflicts with existing aggregation type. - """ - for metric in metrics: - # Same metric name must use same aggregation type - self._validate_metric_consistency(metric) - - metric_key = (metric.source, metric.metric_name) - handler = self._handlers.get(metric.agg_type) - - if handler is None: - raise ValueError( - f"No handler registered for aggregation type: {metric.agg_type}" - ) - - if metric_key not in self._metric_states: - self._metric_states[metric_key] = handler.initialize_metric_state( - metric.source, metric.metric_name, metric.agg_type - ) - - local_agg_metric = self._metric_states[metric_key] - handler.update(local_agg_metric, metric) # Mutates local_agg_metric - - def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: - """Get final metrics for logging in standard format. - - Args: - prefix (str): Prefix for metric names in the returned dictionary - - Returns: - dict[str, float]: Dictionary with keys like "{prefix}_{source}/{metric_name}" - and float values. For example, with `prefix="train"`, `source="alpaca"`, - `metric_name="loss"`, the key would be `train_alpaca/loss`. - """ - final_results = self._compute_unified_metrics() - - return { - f"{prefix}_{result.source}/{result.metric_name}": result.value - for result in final_results - } - - def _compute_unified_metrics(self) -> list[MetricState]: - """ - Compute metrics handling both local and distributed cases uniformly. - - Returns: - list[MetricState]: Final results ready for logging - """ - # Step 1: Get local results from all handlers (may expand stats/categoricals) - prepared_results = [] - for local_agg_metric in self._metric_states.values(): - handler = self._handlers[local_agg_metric.agg_type] - generated_metrics = handler.finalize_local_agg(local_agg_metric) - - # Validate each newly generated metric state immediately - for gen_metric in generated_metrics: - self._validate_metric_consistency(gen_metric) - - prepared_results.extend(generated_metrics) - - # Step 2: Apply distributed reduction if needed - if dist.is_initialized() and dist.get_world_size() > 1: - prepared_results = self._finalize_dist_agg(prepared_results) - - return prepared_results - - def _finalize_dist_agg( - self, local_agg_metrics: list[MetricState] - ) -> list[MetricState]: - """Apply distributed reduction to local results. 
- - Args: - local_agg_metrics (list[MetricState]): (source, metric_name) metric pairs from this rank - - Returns: - list[MetricState]: Reduced results combining all ranks - """ - world_size = dist.get_world_size() - - # Gather all results from all ranks - all_results = [None] * world_size - dist.all_gather_object(all_results, local_agg_metrics) - - # Group by (source, metric_name) for reduction - grouped = defaultdict(list) - for rank_results in all_results: - if rank_results: # Handle ranks with no metrics - for result in rank_results: - result_key = (result.source, result.metric_name) - grouped[result_key].append(result) - - # Apply handler-specific distributed reduction - reduced_results = [] - for result_key, results_list in grouped.items(): - if not results_list: - continue # Skip empty groups - - # All results for a key should have same agg_type - agg_type = results_list[0].agg_type - handler = self._handlers[agg_type] - reduced_result = handler.finalize_dist_agg(results_list) - reduced_results.append(reduced_result) - - return reduced_results - - def state_dict(self) -> dict[str, Any]: - """Serialize aggregator state for checkpointing. - - Returns: - dict[str, Any]: Serializable dictionary containing all aggregator state - """ - serializable_state = {} - required_agg_types = set() # Track aggregation types used in saved states - - for metric_key, local_agg_metric in self._metric_states.items(): - # Get handler for this result's aggregation type - handler = self._handlers[local_agg_metric.agg_type] - required_agg_types.add(local_agg_metric.agg_type) - - # Convert MetricState to serializable dict - result_dict = { - "source": local_agg_metric.source, - "metric_name": local_agg_metric.metric_name, - "value": local_agg_metric.value, - "agg_type": local_agg_metric.agg_type, - "metadata": handler.serialize_metadata(local_agg_metric.metadata), - } - - # Convert tuple key to string for JSON compatibility - serializable_state[str(metric_key)] = result_dict - - return { - "state": serializable_state, - "dist_window_size": self._dist_window_size, - "required_agg_types": list( - required_agg_types - ), # Save which handlers are needed - # Save which aggregation types are used for each metric - "metric_agg_types": { - str(k): v.value for k, v in self._metric_agg_types.items() - }, - } - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load aggregator state from checkpoint. - - Args: - state_dict (dict[str, Any]): Dictionary containing serialized aggregator state - - Raises: - ValueError: If required handlers are missing after checkpoint restore - """ - self._dist_window_size = state_dict.get("dist_window_size", 1000) - - # Sanity check: Ensure all required handlers are available - required_agg_types = state_dict.get("required_agg_types", []) - missing_handlers = [] - for agg_type in required_agg_types: - if agg_type not in self._handlers: - missing_handlers.append(agg_type) - - if missing_handlers: - raise ValueError( - f"Missing handlers for aggregation types: {missing_handlers}. " - f"Custom handlers must be re-registered before checkpoint restore." 
- ) - - deserialized_state = {} - for key_str, result_dict in state_dict["state"].items(): - # Convert string keys back to tuples - metric_key = ast.literal_eval(key_str) - - # Get handler for this aggregation type - agg_type = result_dict["agg_type"] - handler = self._handlers[agg_type] - - # Restore metadata using handler-specific deserialization - metadata = handler.deserialize_metadata(result_dict["metadata"]) - - # Create MetricState from dict - local_agg_metric = MetricState( - source=result_dict["source"], - metric_name=result_dict["metric_name"], - value=result_dict["value"], - agg_type=result_dict["agg_type"], - metadata=metadata, - ) - - deserialized_state[metric_key] = local_agg_metric - - self._metric_states = deserialized_state - - # Restore validation state - self._metric_agg_types = {} - for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items(): - key = ast.literal_eval(key_str) - self._metric_agg_types[key] = AggregationType(agg_type_str) diff --git a/src/forge/data/dataset_metrics/metric_transform.py b/src/forge/data/dataset_metrics/metric_transform.py deleted file mode 100644 index 2898c8e43..000000000 --- a/src/forge/data/dataset_metrics/metric_transform.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Any, Union - -from forge.interfaces import Transform - - -@dataclass(frozen=True) -class Metric: - source: str - metric_name: str - value: Union[int, float, str] - agg_type: "AggregationType" - - -class AggregationType(Enum): - """Defines how a metric's value should be aggregated by the MetricsAggregator. - - Each type corresponds to a specific AggregationHandler that implements the logic - for initialization, updates, and distributed reduction. - """ - - SUM = "sum" - MEAN = "mean" - STATS = "distribution" - CATEGORICAL_COUNT = "categorical_count" - MAX = "max" - MIN = "min" - - -class MetricTransform(Transform, ABC): - """Applied to each dataset sample to generate per-sample metrics for training tracking. - - Creates Metric objects that are later aggregated by MetricsAggregator. This separation - of concerns ensures metrics are correctly aggregated even with multiple dataloader - workers and in distributed settings. - - The transform must be configured with a source via set_source() before use. - Each call to __call__ adds metrics to the sample's "metrics" key. - - Example: - >>> transform = DefaultTrainingMetricTransform() - >>> transform.set_source("alpaca") - >>> sample = {"tokens": [1, 2, 3]} - >>> result = transform(sample) - >>> # result["metrics"] contains list of Metric objects - """ - - def set_source(self, source: str) -> None: - """Called by the dataset to set the namespace for metrics. - - This is used to differentiate metrics from multiple datasets, for example, - "alpaca/tokens_seen" vs. "slim_orca/tokens_seen". - - Args: - source (str): Name of the dataset, used for logging and disambiguation. - """ - self.source = source - - @abstractmethod - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: - """Generate metrics for a single sample. 
- - Args: - sample (dict[str, Any]): The sample dictionary to generate metrics from - - Returns: - list[Metric]: List of metrics generated for this sample - - Raises: - NotImplementedError: If subclass does not implement this method. - """ - raise NotImplementedError("Subclasses must implement _generate_metrics method") - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if not hasattr(self, "source"): - raise RuntimeError( - "'transform.set_source' must be called before using the transform." - ) - - # Generate metrics for this sample - metrics = self._generate_metrics(sample) - - # Add to existing metrics list or create new one - if "metrics" not in sample: - sample["metrics"] = [] - sample["metrics"].extend(metrics) - return sample - - -class DefaultTrainingMetricTransform(MetricTransform): - """Generates common training metrics: samples seen, tokens seen, and sequence length. - - This transform detects the token key in a sample, checking for "tokens" - first and then falling back to "input_ids". - - For details on the base class behavior, see MetricTransform. - - Tracked metrics: - - samples_seen: Cumulative count of samples processed (SUM aggregation) - - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) - - seq_len: Distribution stats of sequence lengths (STATS aggregation) - - Example: - >>> transform = DefaultTrainingMetricTransform() - >>> transform.set_source("alpaca") - >>> - >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens - >>> metrics = transform._generate_metrics(sample) - >>> # This generates the following Metric objects: - >>> # [ - >>> # Metric(source="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), - >>> # Metric(source="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), - >>> # Metric(source="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.STATS) - >>> # ] - """ - - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: - # Determine token key - token_key = "tokens" if "tokens" in sample else "input_ids" - token_len = len(sample.get(token_key, [])) - - # Create metrics for this sample - return [ - Metric( - source=self.source, - metric_name="samples_seen", - value=1, - agg_type=AggregationType.SUM, - ), - Metric( - source=self.source, - metric_name="tokens_seen", - value=token_len, - agg_type=AggregationType.SUM, - ), - Metric( - source=self.source, - metric_name="seq_len", - value=token_len, - agg_type=AggregationType.STATS, - ), - ] diff --git a/src/forge/data/dataset_metrics/readme.md b/src/forge/data/dataset_metrics/readme.md deleted file mode 100644 index 76ec424e3..000000000 --- a/src/forge/data/dataset_metrics/readme.md +++ /dev/null @@ -1,176 +0,0 @@ -# forge Metrics Module - -## Overview - -The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics. 
- -## Architecture Overview - -``` -┌────────────────────────────────────────────────────┐ -│ Training Loop │ -└─────────────────────┬──────────────────────────────┘ - │ -┌─────────────────────▼──────────────────────────────┐ -│ MetricTransform │ -│ • Applied to each sample │ -│ • Generates per-sample metrics │ -│ • Examples: tokens_seen, seq_len, samples_seen │ -└─────────────────────┬──────────────────────────────┘ - │ list[Metric] -┌─────────────────────▼──────────────────────────────┐ -│ MetricsAggregator │ -│ • Aggregates metrics across samples and ranks │ -│ • Uses pluggable AggregationHandlers │ -│ • Handles distributed reduction │ -└─────────────────────┬──────────────────────────────┘ - │ {prefix}_{source}/{metric_name} # prefix is "train", "val", etc. -┌─────────────────────▼──────────────────────────────┐ -│ Logging System │ -│ • W&B, TensorBoard, etc. │ -│ • Gets formatted metrics ready for logging │ -└────────────────────────────────────────────────────┘ -``` - -## File Structure - -- **`metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes -- **`metric_agg_handlers.py`**: Aggregation strategy implementations -- **`metric_aggregator.py`**: Main aggregator orchestrating the handlers - -## Customizing metrics - -- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics -- **Handler registration**: Register custom handlers for specialized aggregation needs - -####### -## TODO -## Move this from here to website docs -####### - -## Core Components - -### 1. MetricTransform -Generates per-sample metrics during data processing. - -**Key Features:** -- Applied to each sample in the dataset -- Creates `Metric` objects with dataset name, metric name, value, and aggregation type -- Handles dataset namespacing for multi-dataset scenarios - -**Example Usage:** -```python -from forge.data.metrics import DefaultTrainingMetricTransform, AggregationType - -transform = DefaultTrainingMetricTransform() -transform.set_source("alpaca") - -# Applied to each sample -sample = {"tokens": [1, 2, 3, 4, 5]} -sample = transform(sample) -# sample["metrics"] now contains: -# [ -# Metric(source="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), -# Metric(source="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), -# Metric(source="alpaca", name="seq_len", value=5, agg_type=AggregationType.STATS) -# ] -``` - -### 2. MetricsAggregator -Efficiently aggregates metrics across samples and distributed ranks. - -**Key Features:** -- Handler-based strategy pattern for different aggregation types -- Distributed-aware with automatic rank reduction -- Checkpointable state for training resumption -- Keep track of (metric, dataset) pairs - -**Aggregation Types (at the time of writing):** -- `SUM`: Cumulative totals (e.g., total tokens processed) -- `MEAN`: Running averages (e.g., average loss) -- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen) -- `STATS`: Statistical summaries (mean, min, max, percentiles) -- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. 
num of samples from a given category) - -**Example Usage:** -```python -from forge.data.metrics import MetricsAggregator, Metric, AggregationType - -# Create aggregator -aggregator = MetricsAggregator() - -# Sample metrics from different batches -batch1_metrics = [ - Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), -] - -batch2_metrics = [ - Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), -] - -# Update with metrics -aggregator.update(batch1_metrics) -aggregator.update(batch2_metrics) - -# Get final results -results = aggregator.get_metrics_for_logging(prefix="train") -# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} -``` - -### 3. AggregationHandlers -Pluggable strategies for different aggregation patterns. - -``` -AggregationHandler (ABC) -├── SumAggHandler # value += metric.value -├── MeanAggHandler # tracks sum and count -├── MaxAggHandler # value = max(value, metric.value) -├── MinAggHandler # value = min(value, metric.value) -├── StatsAggHandler # maintains value window + stats -└── CategoricalCountAggHandler # Counter for categories -``` - -**Custom Handler Example:** -```python -class CustomAggHandler(AggregationHandler): - def initialize_metric_state(self, source, metric_name, agg_type): - return MetricState( - source=source, - metric_name=metric_name, - value=, # should change - agg_type=agg_type, - metadata={} # may need to change - ) - - def update(self, local_agg_metric, metric): - ... - - def finalize_local_agg(self, local_agg_metric): - ... - - def finalize_dist_agg(self, local_agg_metrics): - ... - -# Register with aggregator -aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler()) -``` - -## Distributed Training Support - -The metrics system automatically handles distributed environments: - -1. **Local Aggregation**: Each rank aggregates its own metrics -2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object` -3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.) 
- -**Distributed Flow:** -``` -Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] -Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] - ↓ - AllGather + Reduce - ↓ - Final Results [(ds1, metric1), (ds1, metric2)] -``` diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index 6be68b41b..2b8701b26 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -12,13 +12,9 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from forge.data.dataset_metrics import ( - AggregationType, - DefaultTrainingMetricTransform, - Metric, - MetricTransform, -) +from forge.data.metric_transform import DefaultDatasetMetricTransform, MetricTransform from forge.interfaces import Transform +from forge.observability.metrics import Metric, Reduce from .dataset import DatasetInfo, InfiniteTuneIterableDataset @@ -86,7 +82,7 @@ def __init__( self._weight = weight if weight is not None else 1.0 # Create default transform if not provided - self._metric_transform = metric_transform or DefaultTrainingMetricTransform() + self._metric_transform = metric_transform or DefaultDatasetMetricTransform() # Auto-generate dataset name if not provided if dataset_name is None: @@ -237,18 +233,18 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # .map is applied lazily and the advantage would be to leverage caching. sample = self._apply_transforms(sample) - # Track the number of epochs completed for each dataset. This is - # especially useful when interleaving multiple datasets, but - # also necessary to track dataset-level metrics. - metric_num_epochs = Metric( - source=self.info.name, - metric_name="num_epochs", - value=self._num_epochs, - agg_type=AggregationType.MAX, - ) + # Track the number of epochs completed for each dataset. + # This is especially useful when interleaving multiple datasets. 
if "metrics" not in sample: sample["metrics"] = [] - sample["metrics"].append(metric_num_epochs) + + sample["metrics"].append( + Metric( + key=f"dataset/{self.info.name}/num_epochs", + value=self._num_epochs, + reduction=Reduce.MAX, + ) + ) samples_yielded += 1 yield sample diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index d09c158c0..93a21b85e 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -16,7 +16,7 @@ from torchdata.stateful_dataloader import Stateful from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.data.dataset_metrics import AggregationType, Metric +from forge.observability.metrics import Metric, Reduce from .dataset import DatasetInfo, InfiniteTuneIterableDataset @@ -605,13 +605,13 @@ def finalize_pack( # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) - padding_metric = Metric( - source=self.dataset_name, - metric_name="pct_of_tokens_padded", - value=padding_pct, - agg_type=AggregationType.MEAN, + pack["metrics"].append( + Metric( + key=f"dataset/{self.dataset_name}/pct_of_tokens_padded", + value=padding_pct, + reduction=Reduce.MEAN, + ) ) - pack["metrics"].append(padding_metric) # Concatenate tensor lists and handle other keys result = { @@ -635,7 +635,7 @@ def finalize_pack( if pack["input_pos"] else torch.empty(0, dtype=torch.long) ), - # "metrics": pack["metrics"], + "metrics": pack["metrics"], } # Handle arbitrary keys that aren't tensors - keep as lists @@ -853,13 +853,13 @@ def finalize_pack( # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) - padding_metric = Metric( - source=self.dataset_name, - metric_name="pct_of_tokens_padded", - value=padding_pct, - agg_type=AggregationType.MEAN, + pack["metrics"].append( + Metric( + key=f"dataset/{self.dataset_name}/pct_of_tokens_padded", + value=padding_pct, + reduction=Reduce.MEAN, + ) ) - pack["metrics"].append(padding_metric) # Concatenate tensor lists and handle other keys result = { diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index fca97f912..a354eec95 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -9,7 +9,7 @@ import torch from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.data.dataset_metrics import DefaultTrainingMetricTransform +from forge.data.metric_transform import DefaultDatasetMetricTransform from forge.data.utils import mask_messages, TuneMessage from forge.interfaces import Transform @@ -200,7 +200,7 @@ def sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, output_transform=output_transform, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, weight=weight, seed=seed, diff --git a/src/forge/data/metric_transform.py b/src/forge/data/metric_transform.py new file mode 100644 index 000000000..076556664 --- /dev/null +++ b/src/forge/data/metric_transform.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Any + +from forge.interfaces import Transform +from forge.observability.metrics import Metric, Reduce + + +class MetricTransform(Transform): + """ + Base class for transforms that collect observability metrics from dataset samples. + + This class provides a foundation for implementing dataset-level metric collection + during data processing pipelines. Subclasses should override the __call__ method + to add specific metrics to each sample that passes through the transform. + + Metrics are collected as `forge.observability.metrics.Metric` objects and made available + in batch["metrics"]. + + Attributes: + source (str, optional): The source name for metrics, typically the dataset name. + This is used as a prefix in metric keys to distinguish metrics from different + data sources. + + Example: + >>> transform = SomeMetricTransform() + >>> transform.set_source("training_data") + >>> processed_sample = transform(sample) + >>> # Metrics are automatically added to sample["metrics"] + """ + + def __init__(self): + self.source = None + + def set_source(self, source: str): + """Set the source name for metrics (typically the dataset name).""" + self.source = source + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Transform a sample by adding metrics to it.""" + return sample + + +class DefaultDatasetMetricTransform(MetricTransform): + """ + Collects basic dataset processing metrics during data pipeline execution. + + Metrics collected: + - samples_processed: Total number of samples that have passed through this transform (SUM) + - tokens_processed: Total number of tokens processed across all samples (SUM) + - mean_seq_len: Average sequence length across samples (MEAN) + - max_seq_len: Maximum sequence length observed (MAX) + - min_seq_len: Minimum sequence length observed (MIN) + + Note: Token-related metrics are only collected if the sample contains a 'tokens' field. + Sequence length is measured as the number of tokens in each sample. 
+ + Example: + >>> collector = DefaultDatasetMetricTransform() + >>> collector.set_source("training_data") + >>> sample = {"tokens": ["hello", "world"]} + >>> processed_sample = collector(sample) + >>> # Metrics are automatically added to processed_sample["metrics"] + """ + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + if "metrics" not in sample: + sample["metrics"] = [] + + source_name = self.source or "dataset" + + # Add samples_processed metric + sample["metrics"].append( + Metric( + key=f"dataset/{source_name}/samples_processed", + value=1, + reduction=Reduce.SUM, + ) + ) + + # Add token-based metrics if tokens are present + if "tokens" in sample: + token_count = len(sample.get("tokens", [])) + + sample["metrics"].extend( + [ + Metric( + key=f"dataset/{source_name}/tokens_processed", + value=token_count, + reduction=Reduce.SUM, + ), + Metric( + key=f"dataset/{source_name}/mean_seq_len", + value=token_count, + reduction=Reduce.MEAN, + ), + Metric( + key=f"dataset/{source_name}/max_seq_len", + value=token_count, + reduction=Reduce.MAX, + ), + Metric( + key=f"dataset/{source_name}/min_seq_len", + value=token_count, + reduction=Reduce.MIN, + ), + ] + ) + + return sample diff --git a/src/forge/observability/__init__.py b/src/forge/observability/__init__.py index b970e57fa..0844b56ee 100644 --- a/src/forge/observability/__init__.py +++ b/src/forge/observability/__init__.py @@ -12,9 +12,9 @@ from .metrics import ( BackendRole, ConsoleBackend, - get_actor_name_with_rank, get_logger_backend_class, LoggerBackend, + LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -29,12 +29,12 @@ WandbBackend, ) from .perf_tracker import trace, Tracer +from .utils import detect_actor_name_from_call_stack, get_actor_name_with_rank __all__ = [ # Main API functions "record_metric", "reduce_metrics_states", - "get_actor_name_with_rank", "get_logger_backend_class", "get_or_create_metric_logger", # Performance tracking @@ -45,6 +45,10 @@ "BackendRole", # Enums "Reduce", + "LoggingMode", + # Utility functions + "detect_actor_name_from_call_stack", + "get_actor_name_with_rank", # Actor classes "GlobalLoggingActor", "LocalFetcherActor", diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index fae11556f..935d97ecb 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -15,9 +15,11 @@ BackendRole, get_logger_backend_class, LoggerBackend, + LoggingMode, MetricCollector, reduce_metrics_states, ) +from forge.observability.utils import detect_actor_name_from_call_stack if MONARCH_HOSTMESH_V1.get_value(): from monarch._src.actor.v1.host_mesh import this_proc @@ -33,6 +35,7 @@ async def get_or_create_metric_logger( proc_mesh: ProcMesh | None = None, + process_name: str | None = None, ) -> "GlobalLoggingActor": """Initializes a LocalFetcherActor in the specified process mesh (or current process if None), if not already initialized, registers it with the GlobalLoggingActor and returns the @@ -46,6 +49,8 @@ async def get_or_create_metric_logger( Args: proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `monarch.actor.this_proc()`. + process_name: Optional process name (e.g., "TrainActor", "GeneratorActor") for logging. + If None, will auto-detect from call stack or default to "UnknownActor" if not found. Returns: GlobalLoggingActor: The global logging controller. 
@@ -59,12 +64,12 @@ async def get_or_create_metric_logger( from forge.observability.metrics import record_metric # Main process setup - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") # Initialize logging backends await mlogger.init_backends({ - "console": {"reduce_across_ranks": True}, - "wandb": {"project": "my_project", "reduce_across_ranks": False} + "console": {"logging_mode": "global_reduce"}, + "wandb": {"project": "my_project", "logging_mode": "per_rank_reduce"} }) # Initialize services... @@ -72,13 +77,17 @@ async def get_or_create_metric_logger( # Training loop for step in range(max_steps): - record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN) + record_metric("loss", 1.2, reduction_type=Reduce.MEAN) # ... training code with record_metric() calls ... await mlogger.flush(step) # Log metrics for this step # Shutdown await mlogger.shutdown() """ + + if process_name is None: + process_name = detect_actor_name_from_call_stack() + # Get or create the singleton global logger global _global_logger if _global_logger is None: @@ -107,7 +116,7 @@ async def get_or_create_metric_logger( # Setup local_fetcher_actor if needed (unless disabled by environment flag) if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value(): local_fetcher_actor = proc.spawn( - "local_fetcher_actor", LocalFetcherActor, global_logger + "local_fetcher_actor", LocalFetcherActor, global_logger, process_name ) await global_logger.register_fetcher.call_one(local_fetcher_actor, proc) proc._local_fetcher = local_fetcher_actor # pyre-ignore @@ -123,8 +132,13 @@ class LocalFetcherActor(Actor): GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector """ - def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None: + def __init__( + self, + global_logger: Union["GlobalLoggingActor", None] = None, + process_name: str | None = None, + ) -> None: self.global_logger = global_logger + self.process_name = process_name # Passed to MetricCollector for logging _is_initialized = False @endpoint @@ -151,10 +165,22 @@ async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]], config: dict[str, Any], + global_step: int = 0, ) -> None: - """Init local (per-rank) logger backends and MetricCollector.""" + """Init local (per-rank) logger backends and MetricCollector. + + Args: + metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state. + config (dict[str, Any]): Backend configurations with logging modes and settings. + global_step (int): Initial step for metrics. + """ collector = MetricCollector() - await collector.init_backends(metadata_per_primary_backend, config) + await collector.init_backends( + metadata_per_primary_backend, + config, + global_step, + process_name=self.process_name, + ) @endpoint async def shutdown(self) -> None: @@ -163,22 +189,19 @@ async def shutdown(self) -> None: class GlobalLoggingActor(Actor): - """Coordinates metric logging across all ranks for every training step. + """Coordinates metric logging across all ranks for every global step. Supports multiple logging backends (e.g., WandB, TensorBoard, etc.), for per-rank and/or global reduction logging modes. - If a backend config has flag `reduce_across_ranks=False`, an instance of the backend - is initialized per-rank, otherwise it is done once globally. - This GlobalLoggingActor should be spawned once in the controller. 
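A small sketch of how the new process_name argument resolves, assuming a hypothetical TrainActor that subclasses ForgeActor (names are illustrative): pass it explicitly from the controller, or omit it and let the call-stack detection pick it up.

from forge.observability import get_or_create_metric_logger


async def setup_logging():
    # 1) Explicit name, typical for the main/controller process.
    mlogger = await get_or_create_metric_logger(process_name="Controller")

    # 2) Auto-detection: when called from inside a ForgeActor subclass such as
    #    TrainActor, detect_actor_name_from_call_stack() resolves the name to
    #    "TrainActor", and get_actor_name_with_rank() later formats it as e.g.
    #    "TrainActor_abcd_r0" for per-rank backends.
    mlogger = await get_or_create_metric_logger()
    return mlogger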
A LocalFetcherActor is automatically spawned per-rank in `forge.controller.provisioner.py` and registered with this actor. The LocalFetcherActor is responsible for instantiating - the per-rank MetricCollector. + the per-rank MetricCollector and working as a bridge between GlobalLoggingActor and processes. In summary, the flow is: - - GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector - - GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush + - GlobalLoggingActor.init_backends() -> LocalFetcherActor.init_backends() -> per-rank MetricCollector.init_backends() + - GlobalLoggingActor.flush() -> LocalFetcherActor.flush() -> per-rank MetricCollector.flush """ def __init__(self): @@ -187,45 +210,86 @@ def __init__(self): self.global_logger_backends: dict[str, LoggerBackend] = {} self.metadata_per_primary_backend: dict[str, dict[str, Any]] = {} + def _validate_backend_config( + self, backend_name: str, config: dict[str, Any] + ) -> dict[str, Any]: + """Validate and normalize backend configuration.""" + if "logging_mode" not in config: + logger.debug( + f"logging_mode not provided for backend {backend_name}. Defaulting to global_reduce." + ) + + mode_str = config.get("logging_mode", "global_reduce") + mode = LoggingMode(mode_str) + + # Validate per_rank_share_run configuration + share_run = config.get("per_rank_share_run", False) + if mode == LoggingMode.GLOBAL_REDUCE and share_run: + logger.warning( + f"{backend_name}: per_rank_share_run ignored in {mode.value} mode. " + "Set it to False or change logging_mode to per rank." + ) + + # WandB-specific warning for suboptimal configuration + if ( + backend_name == "wandb" + and mode == LoggingMode.PER_RANK_REDUCE + and share_run + ): + logger.warning( + "WandB: Using 'per_rank_reduce' with 'per_rank_share_run=True' is not recommended. " + "This configuration can lead to confusing metrics where reduced values from multiple ranks " + "are written to the same run. Consider either:\n" + " 1. Set 'per_rank_share_run=False' to create separate runs per rank, OR\n" + " 2. Use 'per_rank_no_reduce' for real-time streaming to a shared run" + ) + + return { + **config, + "logging_mode": mode, + } + @endpoint async def init_backends(self, config: dict[str, Any]) -> None: - """ - Sets config in global actor, so other actors can get it, then eagerly initializes backend and MetricCollectors + """Sets config in global actor, initializes primary backends and eagerly initializes MetricCollectors in all registered fetchers. - A backend is always initialized in the controller (primary backend) and can be used as a logger or as a source - for metadata to be shared with per-rank backends, e.g. shared run IDs for wandb. - - The backend instantiation is controlled by the backend config flag `reduce_across_ranks`: if False, - a per-rank backend is initialized, i.e. if there are 2 ranks, each will have its own backend, - and will log independently, i.e. each rank will have its own run in wandb. - - Else, if True, the GlobalLoggingActor will fetch all local metrics collectors to get their states - and reduce them to a single value, which will be logged by the primary backend in this controller. + The backend instantiation is controlled by the logging_mode field. Primary backends + (instantiated in the controller) can provide metadata to be shared with secondary backends on ranks, + e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`. 
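A rough sketch of what the config validation above does to each backend entry, mirroring the unit test added later in this diff; the backend settings are illustrative:

from forge.observability.metric_actors import GlobalLoggingActor
from forge.observability.metrics import LoggingMode

actor = GlobalLoggingActor()

# String modes are normalized to LoggingMode enum members; questionable
# combinations only emit warnings rather than failing.
cfg = actor._validate_backend_config(
    "wandb",
    {"project": "my_project", "logging_mode": "per_rank_no_reduce", "per_rank_share_run": True},
)
assert cfg["logging_mode"] is LoggingMode.PER_RANK_NO_REDUCE

# A missing logging_mode falls back to global_reduce (with a debug log);
# an unknown string raises ValueError from the LoggingMode() constructor.
cfg = actor._validate_backend_config("console", {})
assert cfg["logging_mode"] is LoggingMode.GLOBAL_REDUCE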
Args: - config (dict[str, Any]): Config for metric logging where keys are backend names, - e.g. {"console": {"reduce_across_ranks": True}, "wandb": {"reduce_across_ranks": False}} + config (dict[str, Any]): Config for metric logging where keys are backend names. + e.g. { + "console": {"logging_mode": "global_reduce"}, + "wandb": {"project": "my_project", "logging_mode": "per_rank_no_reduce"} + } + + Raises: + ValueError: If backend config is invalid or missing required fields. """ - self.config = config + self.config = {} + # Validate and normalize each backend config for backend_name, backend_config in config.items(): + self.config[backend_name] = self._validate_backend_config( + backend_name, backend_config + ) + + # Initialize backends based on logging mode + for backend_name, backend_config in self.config.items(): + mode = backend_config["logging_mode"] + backend = get_logger_backend_class(backend_name)(backend_config) await backend.init(role=BackendRole.GLOBAL) - # Extract metadata from primary logger to be shared with secondary loggers - # and store it - reduce_across_ranks = backend_config.get("reduce_across_ranks", True) - if not reduce_across_ranks: - primary_backend_metadata = ( - backend.get_metadata_for_secondary_ranks() or {} - ) - self.metadata_per_primary_backend[ - backend_name - ] = primary_backend_metadata + # Extract metadata from primary logger to be shared with per-rank loggers + if mode != LoggingMode.GLOBAL_REDUCE: + primary_metadata = backend.get_metadata_for_secondary_ranks() or {} + self.metadata_per_primary_backend[backend_name] = primary_metadata - # Store global logger backends - if reduce_across_ranks: + # Store global logger backends for later flush + if mode == LoggingMode.GLOBAL_REDUCE: self.global_logger_backends[backend_name] = backend # Eager init collectors on all registered fetchers in parallel, passing primary states and config @@ -279,19 +343,21 @@ async def flush(self, global_step: int) -> None: config = self.config if config is None: logger.warning( - "GlobalLoggingActor flush() called before init_backends(). " - "No backends will be flushed." + "Cannot flush collected metrics. GlobalLoggingActor.flush() called before init_backends()." + " No backends will be flushed. 
Please call in your main file:\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`await mlogger.init_backends.call_one(logging_config)`\n" ) return - # if reduce_across_ranks=True, we need to reduce the states from all ranks - # and log with the primary backend + + # Check if need to do reduce and retrieve states from fetchers requires_reduce = any( - backend_config.get("reduce_across_ranks", True) + backend_config["logging_mode"] == LoggingMode.GLOBAL_REDUCE for backend_config in config.values() ) logger.debug( - f"Global flush for global_step {global_step}: {len(self.fetchers)} fetchers" + f"Global flush for global step {global_step}: {len(self.fetchers)} fetchers" ) # Broadcast flush to all fetchers @@ -304,21 +370,25 @@ async def flush(self, global_step: int) -> None: ) if requires_reduce: - # Handle exceptions and extract values from ValueMesh results - all_local_states = [] - for result in results: - if isinstance(result, BaseException): - logger.warning(f"Flush failed on a fetcher: {result}") - continue - - # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] - for gpu_info, local_metric_state in result.items(): - if isinstance(local_metric_state, dict): - all_local_states.append(local_metric_state) - else: - logger.warning( - f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" - ) + + def extract_values_from_valuemesh(results): + all_local_states = [] + for result in results: + if isinstance(result, BaseException): + logger.warning(f"Flush failed on a fetcher: {result}") + continue + + # result is a generator that outputs a pair [{'gpus': i/N}, {metric_key1: metric_state1, ...}}] + for gpu_info, local_metric_state in result.items(): + if isinstance(local_metric_state, dict): + all_local_states.append(local_metric_state) + else: + logger.warning( + f"Unexpected result from fetcher. {gpu_info=}, {local_metric_state=}" + ) + return all_local_states + + all_local_states = extract_values_from_valuemesh(results) if not all_local_states: logger.warning(f"No states to reduce for global_step {global_step}") @@ -327,12 +397,9 @@ async def flush(self, global_step: int) -> None: # Reduce metrics from states reduced_metrics = reduce_metrics_states(all_local_states) - # Log to each global logger_backend - for ( - logger_backend_name, - logger_backend, - ) in self.global_logger_backends.items(): - await logger_backend.log(reduced_metrics, global_step) + # Log to global backends + for backend_name, backend in self.global_logger_backends.items(): + await backend.log_batch(reduced_metrics, global_step) @endpoint def has_fetcher(self, name: str | ProcMesh) -> bool: diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 3ce849ad2..282fd003e 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -13,8 +13,9 @@ from typing import Any import pytz -from monarch.actor import context, current_rank +from monarch.actor import current_rank +from forge.observability.utils import get_actor_name_with_rank from forge.util.logging import log_once logger = logging.getLogger(__name__) @@ -31,6 +32,32 @@ class BackendRole(Enum): GLOBAL = "global" +class LoggingMode(Enum): + """Metric logging behavior for distributed training scenarios. + + Each mode serves different observability needs: + + GLOBAL_REDUCE = "global_reduce" + Best for: Metrics that are best visualized as a single value per step. 
+ Behavior: All ranks accumulate → controller reduces → single log entry + Example use: 8 ranks training, want 1 loss value per training step averaged across all + + PER_RANK_REDUCE = "per_rank_reduce" + Best for: Per-rank performance metrics, debugging individual rank behavior + Behavior: Each rank accumulates + logs its own reduced values + Example use: Monitor GPU utilization per rank, get 8 separate log entries per step + + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + Best for: Real-time streaming, time-series debugging + Behavior: Raw values logged immediately on record_metric() calls. Ignores reduce type. + Example use: See what every rank is doing in real time. + """ + + GLOBAL_REDUCE = "global_reduce" + PER_RANK_REDUCE = "per_rank_reduce" + PER_RANK_NO_REDUCE = "per_rank_no_reduce" + + class Reduce(Enum): MEAN = "mean" SUM = "sum" @@ -55,6 +82,12 @@ class Metric: """Container for metric data including key, value, reduction type, and timestamp. Timestamp is automatically set to current EST time if not provided. + + Args: + key: str + value: Any + reduction: Reduce + timestamp: Optional[float] = None """ key: str @@ -68,55 +101,6 @@ def __post_init__(self): self.timestamp = datetime.now(pytz.UTC).timestamp() -def get_actor_name_with_rank() -> str: - """ - Extracts actor information from Monarch context to form a logging name. - - Returns: - str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). - Falls back to "UnknownActor" if context unavailable. - """ - # Add more defensive checks - ctx = context() - if ctx is None or ctx.actor_instance is None: - logger.warning("Context unavailable, using fallback actor name for logging.") - return "UnknownActor" - - actor_instance = ctx.actor_instance - rank = current_rank() - - actor_id_full = str(actor_instance.actor_id) - - # Parse the actor_id - parts = actor_id_full.split(".") - rank_name = "UnknownActor" # fallback - if len(parts) >= 2: - world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" - actor_part = parts[1] # e.g., "TestActorConfigured[0]" - - # Extract world ID and proc rank - world_id = world_part.split("[")[0] if "[" in world_part else world_part - - # Extract clean actor name (remove "Configured" suffix if present) - if "[" in actor_part: - actor_name = actor_part.split("[")[0] # e.g., "TestActorConfigured" - if actor_name.endswith("Configured"): - actor_name = actor_name[:-10] # Remove "Configured" - else: - actor_name = actor_part - - # Use last 4 characters of world_id as replica identifier - # This is deterministic, readable, and works for any number of replicas - replica_id = world_id[-4:] if len(world_id) >= 4 else world_id - - # Use current_rank().rank as the local rank within the replica - local_rank = rank.rank - - rank_name = f"{actor_name}_{replica_id}_r{local_rank}" - - return rank_name - - def record_metric(key: str, value: Any, reduction: Reduce = Reduce.MEAN) -> None: """Thin wrapper to send metrics to per-rank local MetricCollectors. @@ -150,11 +134,11 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri states is more precise than merging locally reduced metrics. Args: - states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics, + states (list[dict[str, dict[str, Any]]]): list of state of one or more metrics, normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`. 
Returns: - list[Metric]: List of reduced metrics + list[Metric]: list of reduced metrics Example: states = [ @@ -400,22 +384,23 @@ def reset(self) -> None: class MetricCollector: """Per-rank singleton for accumulating, retrieving and flushing metrics to backends. - A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is instantiated once globally, - in the GlobalLoggingActor. + Supports multiple logging backends, each with different logging modes. + For options, check `forge.observability.metrics.LoggerBackend` and `forge.observability.metrics.LoggingMode`. - - Ensures one instance per process; actors call record_metric() which delegates here. + Properties: + - Ensures one instance per rank; + - Using `record_metric()` delegates here; - Init via GlobalLoggingActor -> LocalFetcherActor -> per-rank MetricCollector; - GlobalLoggingActor flushes trigger reductions and log for any locally setup backend. Can optionally also - return non-reduced states for global aggregation. This can be different for each backend. - - Resets accumulators post-flush to avoid leaks across train steps; + return non-reduced states for global aggregation. + - Resets accumulators post-flush to avoid leaks across steps; """ _instances: dict[int, "MetricCollector"] = {} _singleton_rank: int def __new__(cls): - """Singleton per-rank, ensures one instance per process.""" + """Singleton per-rank, ensures one instance per rank.""" rank = current_rank().rank if rank not in cls._instances: @@ -436,31 +421,59 @@ def __init__(self) -> None: self.accumulators: dict[str, MetricAccumulator] = {} self.rank = current_rank().rank - self.logger_backends: list[LoggerBackend] = [] + self.per_rank_reduce_backends: list[LoggerBackend] = [] + self.per_rank_no_reduce_backends: list[LoggerBackend] = [] + self.global_step: int = 0 # Set on `init_backends` and updated on `flush` self._is_initialized = False async def init_backends( self, metadata_per_primary_backend: dict[str, dict[str, Any]] | None, config: dict[str, Any], + global_step: int = 0, + process_name: str | None = None, ) -> None: - """A logger is represented by a backend, i.e. wandb backend. If reduce_across_ranks=False, - the backend is instantiated per-rank, in the MetricCollector, otherwise it is only instantiated - once globally. + """Initialize per-rank logger backends and MetricCollector state. + + A logger backend is represented by a backend class (e.g. WandBBackend, ConsoleBackend). + Backends are categorized by their logging_mode. For details, see `forge.observability.metrics.LoggingMode`. Args: - metadata_per_primary_backend (dict[str, dict[str, Any]] | None): Metadata from primary - logger backend, e.g., {"wandb": {"run_id": "abc123"}}. - config (dict[str, Any]): Logger backend configuration, e.g. {"wandb": {"project": "my_project"}}. + metadata_per_primary_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from primary + logger backends for backends that require shared state across processes, e.g., + {"wandb": {"shared_run_id": "abc123"}}. + config (Dict[str, Any]): Backend configurations where each key is a backend name + and value contains logging_mode and backend-specific settings. + e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}} + global_step (int, default 0): Initial step for logging. Can be used when + resuming from a checkpoint. + process_name (str | None): The meaningful process name for logging. 
""" if self._is_initialized: logger.debug(f"Rank {self.rank}: MetricCollector already initialized") return - # instantiate local backends if any + # Initialize step tracking for immediate logging + self.global_step = global_step + + self.per_rank_reduce_backends: list[LoggerBackend] = [] + self.per_rank_no_reduce_backends: list[LoggerBackend] = [] + + # Initialize backends based on logging mode + # logging_mode is expected to be a LoggingMode enum from GlobalLoggingActor validation for backend_name, backend_config in config.items(): - if backend_config.get("reduce_across_ranks", True): - continue # Skip local backend instantiation and use global instead + mode = backend_config["logging_mode"] + + # Defensive check - logging_mode should already be a LoggingMode enum + if not isinstance(mode, LoggingMode): + raise TypeError( + f"Expected LoggingMode enum for {backend_name}.logging_mode, got {type(mode).__name__}: {mode}." + ) + + # Skip local instantiation for GLOBAL_REDUCE + # Backend will be instantiated in GlobalLoggingActor + if mode == LoggingMode.GLOBAL_REDUCE: + continue # get metadata from primary backend if any primary_metadata = {} @@ -468,27 +481,36 @@ async def init_backends( primary_metadata = metadata_per_primary_backend.get(backend_name, {}) # instantiate local backend - logger_backend = get_logger_backend_class(backend_name)(backend_config) - await logger_backend.init( - role=BackendRole.LOCAL, primary_logger_metadata=primary_metadata + backend = get_logger_backend_class(backend_name)(backend_config) + await backend.init( + role=BackendRole.LOCAL, + primary_logger_metadata=primary_metadata, + process_name=process_name, ) - self.logger_backends.append(logger_backend) + + # Categorize by logging mode + if mode == LoggingMode.PER_RANK_NO_REDUCE: + self.per_rank_no_reduce_backends.append(backend) + else: + self.per_rank_reduce_backends.append(backend) self._is_initialized = True def push(self, metric: Metric) -> None: """Process a metric according to configured logging modes. - Args: - metric: Metric dataclass containing key, value, reduction type, and timestamp. + Behavior depends on backend modes: + - PER_RANK_NO_REDUCE: Stream metric immediately to backends + - PER_RANK_REDUCE/GLOBAL_REDUCE: Accumulate for per step batch logging - Raises: - TypeError: If metric is not a Metric object. + Args: + metric (): Metric dataclass + metric (Metric): Metric dataclass Example: collector = MetricCollector() metric = Metric("loss", 0.5, Reduce.MEAN) - collector.push(metric) + collector.push(metric) # Streams immediately if no_reduce, else accumulates """ if not self._is_initialized: log_once( @@ -498,7 +520,7 @@ def push(self, metric: Metric) -> None: "Skipping metric collection. Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." 
" To disable this warning, please call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "or set env variable `FORGE_DISABLE_METRICS=True`" ), @@ -507,7 +529,13 @@ def push(self, metric: Metric) -> None: # Validate metric object if not isinstance(metric, Metric): - raise TypeError(f"Expected {Metric} object, got {type(metric)}") + raise TypeError( + f"Expected {Metric} object, got {metric} of type {type(metric)}" + ) + + # For PER_RANK_NO_REDUCE backends: stream without reduce + for backend in self.per_rank_no_reduce_backends: + backend.log_stream(metric=metric, global_step=self.global_step) # Always accumulate for reduction and state return key = metric.key @@ -527,7 +555,7 @@ async def flush( return_state (bool): Used by GlobalLoggingActor for reduction across all ranks. If False, returns empty dict, else returns the state of all metrics collected. Returns: - dict[str, dict[str, dict[str, Any]]]: Dict of {metric_key: metric_state}, + dict[str, dict[str, dict[str, Any]]]: dict of {metric_key: metric_state}, e.g., {"loss": {"reduction_type": "mean", "sum": 1.2, "count": 3}}. """ if not self._is_initialized: @@ -536,7 +564,7 @@ async def flush( level=logging.WARNING, msg="Cannot flush collected metrics. MetricCollector.flush() called before init_backends()." "\nPlease call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "before calling `flush`", ) @@ -554,27 +582,29 @@ async def flush( states[key] = acc.get_state() acc.reset() - # Reduce metrics from states for logging if any per-rank backend - if self.logger_backends: - # Use reduce_metrics_states for consistency - reduced_metrics = reduce_metrics_states([states]) + # Reduce and log to PER_RANK_REDUCE backends only (NO_REDUCE backends already logged in push) + if self.per_rank_reduce_backends: + metrics_for_backends = reduce_metrics_states([states]) + + for backend in self.per_rank_reduce_backends: + await backend.log_batch(metrics_for_backends, global_step) - # Log to local logger_backends - for logger_backend in self.logger_backends: - await logger_backend.log(reduced_metrics, global_step) + # Update step (used by NO_REDUCE backends in push) + self.global_step = global_step + 1 return states if return_state else {} async def shutdown(self): """Shutdown logger_backends if initialized.""" + if not self._is_initialized: logger.debug( f"Collector for {get_actor_name_with_rank()} not initialized. Skipping shutdown" ) return - for logger_backend in self.logger_backends: - await logger_backend.finish() + for backend in self.per_rank_reduce_backends + self.per_rank_no_reduce_backends: + await backend.finish() ########### @@ -593,6 +623,7 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: """ Initializes backend, e.g. wandb.run.init(). @@ -602,19 +633,30 @@ async def init( Can be used to behave differently for primary vs secondary roles. primary_logger_metadata (dict[str, Any] | None): From global backend for backend that required shared info, e.g. {"shared_run_id": "abc123"}. + process_name (str | None): Process name for logging. Raises: ValueError if missing metadata for shared local init. 
""" pass @abstractmethod - async def log(self, metrics: list[Metric], global_step: int) -> None: - """ - Log a batch of metrics to the backend. + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: + """Log batch of accumulated metrics to backend""" + pass - Args: - metrics: List of Metric objects to log. - global_step: Step number for x-axis alignment across metrics. + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to backend immediately. + + NOTE: This method is called synchronously. + If your backend requires async I/O operations: + - Use asyncio.create_task() for fire-and-forget logging + - Consider internal buffering to avoid blocking the caller + + Example for async backend: + def log_stream(self, metric, global_step): + asyncio.create_task(self._async_log(metric, global_step)) """ pass @@ -636,14 +678,13 @@ async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: - self.prefix = ( - get_actor_name_with_rank() - if self.logger_backend_config.get("reduce_across_ranks", True) - else "Controller" - ) + self.prefix = get_actor_name_with_rank(actor_name=process_name) - async def log(self, metrics: list[Metric], global_step: int) -> None: + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: metrics_str = "\n".join( f" {metric.key}: {metric.value}" for metric in sorted(metrics, key=lambda m: m.key) @@ -652,24 +693,27 @@ async def log(self, metrics: list[Metric], global_step: int) -> None: f"=== [{self.prefix}] - METRICS STEP {global_step} ===\n{metrics_str}\n==============================\n" ) + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream metric to console immediately.""" + logger.info(f"{metric.key}: {metric.value}") + async def finish(self) -> None: pass class WandbBackend(LoggerBackend): """ - Weights & Biases logging backend for distributed training. + Weights & Biases logging backend. - Supports 3 types of modes as described in https://docs.wandb.ai/guides/track/log/distributed-training/: - Track a single process: reduce_across_ranks=True - Track each process separately: reduce_across_ranks=False, share_run_id=False - Track all processes to a single run: reduce_across_ranks=False, share_run_id=True + For logging mode details, see `forge.observability.metrics.LoggingMode` documentation. + + More details on wandb distributed logging here: https://docs.wandb.ai/guides/track/log/distributed-training/ Configuration: - reduce_across_ranks (bool, default True): If True, log reduced metrics only from controller (global mode). - If False, enables per-rank logging; then use share_run_id to pick mode. - share_run_id (bool, default False): Only used if reduce_across_ranks=False. - True -> shared run across ranks; False -> separate runs per rank. + logging_mode (LoggingMode): Determines logging behavior + per_rank_share_run (bool, default False): For per-rank modes, whether to share run ID across ranks. + If true, then a single wandb is created and all ranks log to it. Its particularly useful if + logging with no_reduce to capture a time based stream of information. Not recommended if reducing values. project (str): WandB project name group (str, optional): WandB group name for organizing runs. 
Defaults to "experiment_group" """ @@ -680,41 +724,34 @@ def __init__(self, logger_backend_config: dict[str, Any]) -> None: self.group = logger_backend_config.get("group", "experiment_group") self.name = None self.run = None - self.reduce_across_ranks = logger_backend_config.get( - "reduce_across_ranks", True - ) - self.share_run_id = logger_backend_config.get("share_run_id", False) + self.logging_mode = LoggingMode(logger_backend_config["logging_mode"]) + self.per_rank_share_run = logger_backend_config.get("per_rank_share_run", False) async def init( self, role: BackendRole, primary_logger_metadata: dict[str, Any] | None = None, + process_name: str | None = None, ) -> None: if primary_logger_metadata is None: primary_logger_metadata = {} - self.name = ( - get_actor_name_with_rank() - if role == BackendRole.LOCAL - else "global_controller" - ) + self.name = get_actor_name_with_rank(actor_name=process_name) - # Default global mode: only inits on controller - if self.reduce_across_ranks: + # GLOBAL_REDUCE mode: only inits on controller + if self.logging_mode == LoggingMode.GLOBAL_REDUCE: if role != BackendRole.GLOBAL: - logger.debug( - f"Skipped init for global mode (reduce_across_ranks=True) and {role} role." - ) + logger.warning(f"Skipped init for GLOBAL_REDUCE mode and {role} role.") return await self._init_global() - # Per-rank modes based on share_run_id bool - elif role == BackendRole.GLOBAL and self.share_run_id: + # Per-rank modes based on per_rank_share_run bool + elif role == BackendRole.GLOBAL and self.per_rank_share_run: await self._init_shared_global() elif role == BackendRole.LOCAL: - if self.share_run_id: + if self.per_rank_share_run: await self._init_shared_local(primary_logger_metadata) else: await self._init_per_rank() @@ -762,22 +799,38 @@ async def _init_shared_local(self, primary_metadata: dict[str, Any]): settings=settings, ) - async def log(self, metrics: list[Metric], global_step: int) -> None: - if self.run: - # Convert metrics to WandB log format - log_data = {"global_step": global_step} - for metric in metrics: - log_data[metric.key] = metric.value - - self.run.log(log_data) - logger.info( - f"WandbBackend: Logged {len(metrics)} metrics at global_step {global_step}" - ) - else: + async def log_batch( + self, metrics: list[Metric], global_step: int, *args, **kwargs + ) -> None: + if not self.run: logger.debug(f"WandbBackend: No run started, skipping log for {self.name}") + return + + # Convert metrics to WandB log format + log_data = {"step": global_step} + for metric in metrics: + log_data[metric.key] = metric.value + + self.run.log(log_data) + logger.info( + f"WandbBackend: Logged {len(metrics)} metrics at step {global_step}" + ) + + def log_stream(self, metric: Metric, global_step: int, *args, **kwargs) -> None: + """Stream single metric to WandB with both step and timestamp.""" + if not self.run: + return + + # Log with both step and timestamp - users can choose x-axis in WandB UI + log_data = { + metric.key: metric.value, + "global_step": global_step, + "_timestamp": metric.timestamp, + } + self.run.log(log_data) def get_metadata_for_secondary_ranks(self) -> dict[str, Any]: - if self.run and not self.reduce_across_ranks and self.share_run_id: + if self.run and self.per_rank_share_run: return {"shared_run_id": self.run.id} return {} diff --git a/src/forge/observability/utils.py b/src/forge/observability/utils.py new file mode 100644 index 000000000..f9fc18014 --- /dev/null +++ b/src/forge/observability/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, 
Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Optional + +from monarch.actor import context, current_rank + +logger = logging.getLogger(__name__) + + +def detect_actor_name_from_call_stack() -> str: + """Detect ForgeActor subclass name from call stack. + + Returns: + str: Actor name, defaulting to "UnknownActor" if not found. + """ + try: + import inspect + + frame = inspect.currentframe() + frame_count = 0 + + while frame: + frame = frame.f_back + if not frame: + break + + frame_count += 1 + if frame_count > 20: # Prevent infinite loops + break + + # Check for 'self' (instance method calls) + if "self" in frame.f_locals: + obj = frame.f_locals["self"] + if hasattr(obj, "__class__") and hasattr(obj.__class__, "__mro__"): + for base in obj.__class__.__mro__: + if base.__name__ == "ForgeActor": + return obj.__class__.__name__ + + # Check for 'cls' (class method calls) + if "cls" in frame.f_locals: + cls = frame.f_locals["cls"] + if hasattr(cls, "__mro__"): + for base in cls.__mro__: + if base.__name__ == "ForgeActor": + return cls.__name__ + + except Exception as e: + logger.debug(f"Call stack detection failed: {e}") + + return "UnknownActor" + + +def get_actor_name_with_rank(actor_name: Optional[str] = None) -> str: + """ + Extracts actor information from Monarch context to form a logging name. + + Args: + actor_name: Optional actor name to use. If None, will auto-detect from call stack. + + Returns: + str: Format "ActorName_replicaId_rLocalRank" (e.g., "TrainActor_abcd_r0"). + Falls back to "UnknownActor" if context unavailable. + """ + ctx = context() + if ctx is None or ctx.actor_instance is None: + logger.warning("Context unavailable, using fallback actor name for logging.") + return "UnknownActor" + + actor_instance = ctx.actor_instance + rank = current_rank() + actor_id_full = str(actor_instance.actor_id) + + # Parse the actor_id + parts = actor_id_full.split(".") + if len(parts) < 2: + return "UnknownActor" + + world_part = parts[0] # e.g., "_1rjutFUXQrEJ[0]" + actor_part = parts[1] # e.g., "TestActorConfigured[0]" + + # Use provided actor name or auto-detect from call stack + if actor_name: + final_actor_name = actor_name + else: + final_actor_name = detect_actor_name_from_call_stack() + + # Use last 4 characters of world_id as replica identifier + world_id = world_part.split("[")[0] if "[" in world_part else world_part + replica_id = world_id[-4:] if len(world_id) >= 4 else world_id + + return f"{final_actor_name}_{replica_id}_r{rank.rank}" diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index d999fb700..f61165f42 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -82,20 +82,18 @@ async def main(): group = f"grpo_exp_{int(time.time())}" # Config format: {backend_name: backend_config_dict} - # Each backend can specify reduce_across_ranks to control distributed logging behavior config = { - "console": {"reduce_across_ranks": True}, + "console": {"logging_mode": "global_reduce"}, "wandb": { - "project": "my_project", + "project": "toy_metrics", "group": group, - "reduce_across_ranks": False, - # Only useful if NOT reduce_across_ranks. - "share_run_id": False, # Share run ID across ranks -- Not recommended. 
+ "logging_mode": "per_rank_reduce", # global_reduce, per_rank_reduce, per_rank_no_reduce + "per_rank_share_run": True, }, } service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False} - mlogger = await get_or_create_metric_logger() + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(config) # Spawn services first (triggers registrations via provisioner hook) diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 54b093841..ba8886621 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -32,8 +32,10 @@ async def run(cfg: DictConfig): await init_provisioner( ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner)) ) - metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}}) - mlogger = await get_or_create_metric_logger() + metric_logging_cfg = cfg.get( + "metric_logging", {"console": {"logging_mode": "global_reduce"}} + ) + mlogger = await get_or_create_metric_logger(process_name="Controller") await mlogger.init_backends.call_one(metric_logging_cfg) if (prompt := cfg.get("prompt")) is None: diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index c1535c8b8..b115e4a7e 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -26,8 +26,8 @@ import pytest import torch.distributed as dist -from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator from forge.data.datasets import HfIterableDataset +from forge.data.metric_transform import DefaultDatasetMetricTransform from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader @@ -93,7 +93,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", # dataset_name not provided - should auto-generate seed=SEED, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=None, # Now using new observability system num_shards_per_rank=4, ) @@ -131,7 +131,7 @@ def test_default_dataset_name(self, small_dataset_file): dataset_name="my_dataset", weight=custom_weight, seed=SEED, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=None, # Now using new observability system num_shards_per_rank=4, ) @@ -149,17 +149,16 @@ def test_epoch_boundaries_and_checkpointing( the epoch metric is correct, and checkpointing works as expected. """ - # 1. Setup Dataloaders and Aggregators for original and resumed runs - def create_loader_and_aggregator(): + # 1. Setup Dataloaders for original and resumed runs + def create_loader(): dataset = dataset_factory(small_dataset_file, shuffle=False) loader = StatefulDataLoader( dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - aggregator = MetricsAggregator() - return loader, aggregator + return loader - loader1, aggregator1 = create_loader_and_aggregator() - loader2, aggregator2 = create_loader_and_aggregator() + loader1 = create_loader() + loader2 = create_loader() # 2. Calculate steps for the test run total_samples = int(SMALL_DATASET_SIZE * num_epochs) @@ -171,11 +170,9 @@ def create_loader_and_aggregator(): # 3. 
Generate checkpoint and resume result = generate_ckpt( loader1, - aggregator1, steps_before_checkpoint=steps_before_checkpoint, steps_after_checkpoint=steps_after_checkpoint, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # 4. Verify checkpointing and resumption @@ -253,9 +250,7 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in first_epoch_samples: first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value - for metric in first_epoch_metrics - if metric.metric_name == "epoch" + metric.value for metric in first_epoch_metrics if "num_epochs" in metric.key ] assert all( epoch_value == 0 for epoch_value in epoch_values @@ -268,7 +263,7 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): epoch_values = [ metric.value for metric in second_epoch_metrics - if metric.metric_name == "epoch" + if "num_epochs" in metric.key ] assert all( epoch_value == 1 for epoch_value in epoch_values @@ -314,7 +309,7 @@ def test_distributed_epoch_boundary_checkpointing(self): # Test multiple epoch boundaries for num_epochs in [0.9, 1.0, 2.5]: - def create_loader_and_aggregator(): + def create_loader(): dataset = HfIterableDataset( path="json", data_files=str(medium_dataset_file), @@ -322,7 +317,7 @@ def create_loader_and_aggregator(): dataset_name="epoch_test", seed=SEED, shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=None, # Now using new observability system num_shards_per_rank=2, ) loader = StatefulDataLoader( @@ -331,10 +326,10 @@ def create_loader_and_aggregator(): collate_fn=collate_with_metrics, num_workers=0, ) - return loader, MetricsAggregator() + return loader - loader1, aggregator1 = create_loader_and_aggregator() - loader2, aggregator2 = create_loader_and_aggregator() + loader1 = create_loader() + loader2 = create_loader() # Calculate steps to reach desired epoch boundary samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() @@ -352,11 +347,9 @@ def create_loader_and_aggregator(): result = generate_ckpt( loader1, - aggregator1, - steps_before, - steps_after, + steps_before_checkpoint=steps_before, + steps_after_checkpoint=steps_after, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # Verify deterministic resumption - critical for distributed training @@ -375,7 +368,7 @@ def create_loader_and_aggregator(): num_epochs - 1e-9 ) # -1e-9 so 1.0 epochs -> 0 assert ( - final_metrics["train_epoch_test/num_epochs"] == expected_epoch + final_metrics["dataset/epoch_test/num_epochs"] == expected_epoch ), f"Epoch count incorrect for {num_epochs} epochs test scenario" finally: diff --git a/tests/unit_tests/datasets/test_interleaved.py b/tests/unit_tests/datasets/test_interleaved.py index 0073b905e..f7b616b7a 100644 --- a/tests/unit_tests/datasets/test_interleaved.py +++ b/tests/unit_tests/datasets/test_interleaved.py @@ -28,9 +28,9 @@ import torch import torch.distributed as dist - -from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator from forge.data.datasets import HfIterableDataset, InterleavedDataset + +from forge.data.metric_transform import DefaultDatasetMetricTransform from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader @@ -114,7 +114,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=DefaultTrainingMetricTransform(), + 
metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -299,37 +299,47 @@ def test_metrics_aggregation( [child_interleaved, ds3], seed=SEED, dataset_name="parent" ) - aggregator = MetricsAggregator() + # Collect metrics manually instead of using old MetricsAggregator + collected_metrics = [] # Process some samples total_samples = 300 for sample in islice(iter(parent_interleaved), total_samples): - aggregator.update(sample["metrics"]) - - metrics = aggregator.get_metrics_for_logging(prefix="train") - - # Should have metrics from all three datasets, with flat keys - assert "train_ds1/samples_seen" in metrics - assert "train_ds2/samples_seen" in metrics - assert "train_ds3/samples_seen" in metrics + if "metrics" in sample: + collected_metrics.extend(sample["metrics"]) + + # Count metrics by dataset name (using new metric key) + ds1_samples_processed = sum( + 1 + for m in collected_metrics + if hasattr(m, "key") and "dataset/ds1/samples_processed" in m.key + ) + ds2_samples_processed = sum( + 1 + for m in collected_metrics + if hasattr(m, "key") and "dataset/ds2/samples_processed" in m.key + ) + ds3_samples_processed = sum( + 1 + for m in collected_metrics + if hasattr(m, "key") and "dataset/ds3/samples_processed" in m.key + ) # All datasets should have contributed samples - assert metrics["train_ds1/samples_seen"] > 0 - assert metrics["train_ds2/samples_seen"] > 0 - assert metrics["train_ds3/samples_seen"] > 0 + assert ds1_samples_processed > 0, "ds1 should have contributed samples" + assert ds2_samples_processed > 0, "ds2 should have contributed samples" + assert ds3_samples_processed > 0, "ds3 should have contributed samples" # Total samples should equal what we processed calculated_total_samples = ( - metrics["train_ds1/samples_seen"] - + metrics["train_ds2/samples_seen"] - + metrics["train_ds3/samples_seen"] + ds1_samples_processed + ds2_samples_processed + ds3_samples_processed ) assert calculated_total_samples == total_samples # Test that ratios are approximately correct based on nested weighting - ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples - ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples - ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples + ds1_ratio = ds1_samples_processed / total_samples + ds2_ratio = ds2_samples_processed / total_samples + ds3_ratio = ds3_samples_processed / total_samples # Expected ratios based on nested weighting: # Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0 @@ -377,22 +387,18 @@ def create_interleaved(): loader1 = StatefulDataLoader( interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - aggregator1 = MetricsAggregator() # Resumed run interleaved2 = create_interleaved() loader2 = StatefulDataLoader( interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - aggregator2 = MetricsAggregator() result = generate_ckpt( loader1, - aggregator1, steps_before_checkpoint=10, steps_after_checkpoint=20, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] @@ -512,7 +518,7 @@ def create_dataset(): split="train", dataset_name="ds1", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, weight=0.3, ) @@ -522,7 +528,7 @@ def create_dataset(): split="train", dataset_name="ds2", shuffle_buffer_size=0, # No shuffle for determinism - 
metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, weight=0.7, ) @@ -532,7 +538,7 @@ def create_dataset(): split="train", dataset_name="ds3", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, weight=1.0, ) @@ -552,19 +558,17 @@ def create_dataloader(dataset): num_workers=0, # Avoid multiprocessing in distributed tests collate_fn=collate_with_metrics, ) - return loader, MetricsAggregator() + return loader # Run checkpointing test with small number of steps - loader1, aggregator1 = create_dataloader(create_dataset()) - loader2, aggregator2 = create_dataloader(create_dataset()) + loader1 = create_dataloader(create_dataset()) + loader2 = create_dataloader(create_dataset()) result = generate_ckpt( loader1, - aggregator1, - 3, - 3, # 3 steps before, 3 steps after checkpoint + steps_before_checkpoint=3, + steps_after_checkpoint=3, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # Verify deterministic resumption diff --git a/tests/unit_tests/datasets/test_iterable_utils.py b/tests/unit_tests/datasets/test_iterable_utils.py index cdeced7c7..95bfc057b 100644 --- a/tests/unit_tests/datasets/test_iterable_utils.py +++ b/tests/unit_tests/datasets/test_iterable_utils.py @@ -7,92 +7,91 @@ from typing import Any, Optional import torch -from forge.data.dataset_metrics import MetricsAggregator - from torch.utils.data import DataLoader -def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Simple collate that extracts metrics and pads tokens.""" - all_metrics = [] - clean_batch = [] +def collate_with_metrics(batch): + """ + Simple collate function that preserves metrics for validation. + Collects metrics from all samples in the batch and aggregates them. + + Uses a simple collation that doesn't enforce same sizes for lists/tokens. 
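A rough usage sketch of this collate helper with toy values, showing how per-sample metrics are hoisted to the batch level while ragged token lists are kept un-padded:

batch = [
    {"id": 0, "tokens": [1, 2, 3], "metrics": []},
    {"id": 1, "tokens": [4, 5], "metrics": []},
]
out = collate_with_metrics(batch)
# out["id"]     -> tensor([0, 1])        (via torch.utils.data.default_collate)
# out["tokens"] -> [[1, 2, 3], [4, 5]]   (kept as a list of lists, no padding)
# a "metrics" key is only added when at least one sample carried metrics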
+ """ + # Collect metrics from all samples + batch_metrics = [] for sample in batch: if "metrics" in sample: - all_metrics.extend(sample.pop("metrics")) - clean_batch.append(sample) - - if not clean_batch: - return {"metrics": all_metrics} - - # Simple padding for tokens - ids = torch.tensor([item["id"] for item in clean_batch]) - tokens = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(item["tokens"]) for item in clean_batch], - batch_first=True, - padding_value=-1, # Use -1 for padding to distinguish from valid IDs - ) - collated = { - "id": ids, - "tokens": tokens, - } - - # Add text field for non-tensor data - if "text" in clean_batch[0]: - collated["text"] = [item["text"] for item in clean_batch] + batch_metrics.extend(sample.pop("metrics")) + + # Simple collation that handles variable-length sequences + collated = {} + if batch: + for key in batch[0].keys(): + values = [sample[key] for sample in batch] + if key == "tokens" or key == "labels": + # Keep as list of lists for variable length sequences + collated[key] = values + else: + # Use default collation for scalars + collated[key] = torch.utils.data.default_collate(values) + + # Add batch-level metrics key for downstream processing + if batch_metrics: + collated["metrics"] = batch_metrics - collated["metrics"] = all_metrics return collated def generate_ckpt( dataloader: DataLoader, - aggregator: MetricsAggregator, steps_before_checkpoint: int, steps_after_checkpoint: int, resume_dataloader: Optional[DataLoader] = None, - resume_aggregator: Optional[MetricsAggregator] = None, ) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. - Optionally, a second dataloader and aggregator can be given to resume from ckpt + Optionally, a second dataloader can be given to resume from checkpoint and run steps_after_checkpoint to match the first one. + Collects and aggregates metrics for test validation purposes. + Args: dataloader (DataLoader): The dataloader to test - aggregator (MetricsAggregator): The metrics aggregator to use steps_before_checkpoint (int): Number of steps to run before saving checkpoint steps_after_checkpoint (int): Number of steps to run after checkpoint resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming. If None, returns empty resumed_batches. - resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. - If None, returns empty resumed_metrics. Returns: - dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. + dict[str, Any]: Dict with batches and aggregated metrics for validation. 
""" iterator = iter(dataloader) - # Collect batches before and after checkpoint + # Collect batches and metrics before and after checkpoint batches = [] + all_metrics = [] # All metrics collected during the run + checkpoint_metrics = [] # Metrics collected only up to checkpoint checkpoint_state = None - metrics_at_checkpoint = {} total_steps = steps_before_checkpoint + steps_after_checkpoint for idx, batch in enumerate(iterator): batches.append(batch) - # Process metrics + # Collect metrics for test validation if "metrics" in batch: - aggregator.update(batch.pop("metrics")) + batch_metrics = batch.pop("metrics") + all_metrics.extend(batch_metrics) + + # If we haven't reached checkpoint yet, also add to checkpoint metrics + if idx < steps_before_checkpoint: + checkpoint_metrics.extend(batch_metrics) # Save checkpoint state after steps_before_checkpoint if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based checkpoint_state = { "loader": dataloader.state_dict(), - "aggregator": aggregator.state_dict(), } - metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") # Stop after total steps if idx == total_steps - 1: @@ -102,43 +101,50 @@ def generate_ckpt( pre_checkpoint_batches = batches[:steps_before_checkpoint] post_checkpoint_batches = batches[steps_before_checkpoint:] - # Resume with new instances if provided + # Resume with new instance if provided resumed_batches = [] - resumed_metrics = {} - - if ( - resume_dataloader is not None - and resume_aggregator is not None - and checkpoint_state is not None - ): - # Test resuming with new instances + resumed_metrics = [] + + if resume_dataloader is not None and checkpoint_state is not None: + # Test resuming with new instance resume_dataloader.load_state_dict(checkpoint_state["loader"]) - resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) resume_iterator = iter(resume_dataloader) # Collect only the post-checkpoint batches when resuming for idx, batch in enumerate(resume_iterator): resumed_batches.append(batch) - # Process metrics + # Collect metrics from resumed batches if "metrics" in batch: - resume_aggregator.update(batch.pop("metrics")) + batch_metrics = batch.pop("metrics") + resumed_metrics.extend(batch_metrics) # Stop after steps_after_checkpoint if idx == steps_after_checkpoint - 1: break - resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") - return { # Original run "pre_checkpoint_batches": pre_checkpoint_batches, "post_checkpoint_batches": post_checkpoint_batches, - "metrics_at_checkpoint": metrics_at_checkpoint, - "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + "metrics_at_checkpoint": keep_last_metric(checkpoint_metrics), + "final_metrics": keep_last_metric(all_metrics), # Resumed run "resumed_batches": resumed_batches, - "resumed_metrics": resumed_metrics, + "resumed_metrics": keep_last_metric(resumed_metrics), # Internal state for loading - only if someone needs to manually load "_checkpoint_state": checkpoint_state, } + + +def keep_last_metric(metrics_list: list) -> dict[str, Any]: + result = {} + for metric in metrics_list: + # Expect observability.Metric objects only + key = metric.key + value = metric.value + + # For test purposes, just keep the last value of each metric + result[key] = value + + return result diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 352fbf703..56cd5ff02 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ 
-14,7 +14,6 @@ import torch from forge.data.collate import collate_packed -from forge.data.dataset_metrics import MetricsAggregator from forge.data.datasets import HfIterableDataset from forge.data.datasets.packed import ( _SUPPORTS_FLEX_ATTENTION, @@ -914,7 +913,7 @@ def test_checkpoint_and_resume(self, dataset_factory): batch_size = 1 # Setup dataset factory - def create_loader_and_aggregator(): + def create_loader(): dataset = dataset_factory(samples) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = PackedDataset( @@ -931,11 +930,10 @@ def create_loader_and_aggregator(): loader = StatefulDataLoader( packed_dataset, batch_size=batch_size, collate_fn=collate_fn ) - aggregator = MetricsAggregator() - return loader, aggregator + return loader - loader1, aggregator1 = create_loader_and_aggregator() - loader2, aggregator2 = create_loader_and_aggregator() + loader1 = create_loader() + loader2 = create_loader() steps_before_checkpoint = 2 steps_after_checkpoint = 2 @@ -943,11 +941,9 @@ def create_loader_and_aggregator(): # Generate checkpoint and resume result = generate_ckpt( loader1, - aggregator1, steps_before_checkpoint=steps_before_checkpoint, steps_after_checkpoint=steps_after_checkpoint, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # Verify that checkpointing and resumption work diff --git a/tests/unit_tests/observability/conftest.py b/tests/unit_tests/observability/conftest.py index e8900392c..e35350c11 100644 --- a/tests/unit_tests/observability/conftest.py +++ b/tests/unit_tests/observability/conftest.py @@ -22,13 +22,14 @@ def __init__(self, logger_backend_config=None): self.finish_called = False self.metadata = {} - async def init(self, role="local", primary_logger_metadata=None): + async def init(self, role="local", primary_logger_metadata=None, process_name=None): self.init_called = True self.role = role self.primary_logger_metadata = primary_logger_metadata or {} + self.process_name = process_name - async def log(self, metrics, step): - self.logged_metrics.append((metrics, step)) + async def log(self, metrics, global_step): + self.logged_metrics.append((metrics, global_step)) async def finish(self): self.finish_called = True diff --git a/tests/unit_tests/observability/test_metric_actors.py b/tests/unit_tests/observability/test_metric_actors.py new file mode 100644 index 000000000..8a15d4497 --- /dev/null +++ b/tests/unit_tests/observability/test_metric_actors.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +"""Optimized unit tests for metric actors functionality.""" + +from unittest.mock import patch + +import pytest + +from forge.observability.metric_actors import ( + get_or_create_metric_logger, + GlobalLoggingActor, + LocalFetcherActor, +) + +from forge.observability.metrics import LoggingMode +from monarch.actor import this_host + + +@pytest.fixture +def global_logger(): + """Create a GlobalLoggingActor for testing.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestGlobalLogger", GlobalLoggingActor) + + +@pytest.fixture +def local_fetcher(global_logger): + """Create a LocalFetcherActor linked to global logger.""" + p = this_host().spawn_procs(per_host={"cpus": 1}) + return p.spawn("TestLocalFetcher", LocalFetcherActor, global_logger) + + +class TestBasicOperations: + """Test basic operations for actors.""" + + @pytest.mark.asyncio + async def test_local_fetcher_flush(self, local_fetcher): + """Test LocalFetcherActor flush operations.""" + result_with_state = await local_fetcher.flush.call_one( + global_step=1, return_state=True + ) + assert result_with_state == {} + + result_without_state = await local_fetcher.flush.call_one( + global_step=1, return_state=False + ) + assert result_without_state == {} + + @pytest.mark.asyncio + async def test_global_logger_basic_ops(self, global_logger): + """Test GlobalLoggingActor basic operations.""" + count = await global_logger.get_fetcher_count.call_one() + assert count >= 0 + + has_fetcher = await global_logger.has_fetcher.call_one("nonexistent") + assert has_fetcher is False + + # Global logger flush (should not raise error) + await global_logger.flush.call_one(global_step=1) + + @pytest.mark.asyncio + async def test_backend_init(self, global_logger): + """Test backend initialization and shutdown through proper validation flow.""" + # Use global_logger to ensure proper validation flow (string -> enum conversion) + config = {"console": {"logging_mode": "per_rank_reduce"}} + + await global_logger.init_backends.call_one(config) + await global_logger.shutdown.call_one() + + +class TestRegistrationLifecycle: + """Test registration lifecycle.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_registration_lifecycle(self, global_logger, local_fetcher): + """Test complete registration/deregistration lifecycle.""" + proc_name = "lifecycle_test_proc" + + # Initial state + initial_count = await global_logger.get_fetcher_count.call_one() + assert await global_logger.has_fetcher.call_one(proc_name) is False + + # Register + await global_logger.register_fetcher.call_one(local_fetcher, proc_name) + + # Verify registered + new_count = await global_logger.get_fetcher_count.call_one() + assert new_count == initial_count + 1 + assert await global_logger.has_fetcher.call_one(proc_name) is True + + # Deregister + await global_logger.deregister_fetcher.call_one(proc_name) + + # Verify deregistered + final_count = await global_logger.get_fetcher_count.call_one() + assert final_count == initial_count + assert await global_logger.has_fetcher.call_one(proc_name) is False + + +class TestBackendConfiguration: + """Test backend configuration validation.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_valid_backend_configs(self, global_logger): + """Test valid backend configurations.""" + # Empty config + await global_logger.init_backends.call_one({}) + + # Valid configs for all logging modes + for mode in ["per_rank_reduce", "per_rank_no_reduce", "global_reduce"]: + config = {"console": {"logging_mode": 
mode}} + await global_logger.init_backends.call_one(config) + + def test_invalid_backend_configs(self): + """Test invalid backend configurations and warnings using direct validation.""" + actor = GlobalLoggingActor() + + # Test 1: Invalid logging_mode should raise ValueError + with pytest.raises(ValueError, match="is not a valid LoggingMode"): + actor._validate_backend_config("console", {"logging_mode": "invalid_mode"}) + + # Test 2: WandB PER_RANK_REDUCE + per_rank_share_run=True should warn + with patch("forge.observability.metric_actors.logger.warning") as mock_warn: + config = { + "logging_mode": "per_rank_reduce", + "per_rank_share_run": True, + "project": "test_project", + } + + result = actor._validate_backend_config("wandb", config) + + # Should have logged warning about suboptimal config + mock_warn.assert_called_once() + warning_msg = str(mock_warn.call_args) + assert "not recommended" in warning_msg + + # Should still return valid config with LoggingMode enum + assert result["logging_mode"] == LoggingMode.PER_RANK_REDUCE + assert result["per_rank_share_run"] is True + assert result["project"] == "test_project" + + +class TestErrorHandling: + """Test error handling scenarios.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_deregister_nonexistent_fetcher(self, global_logger): + """Test deregistering non-existent fetcher doesn't crash.""" + await global_logger.deregister_fetcher.call_one("nonexistent_proc") + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_shutdown(self, global_logger): + """Test shutdown without issues.""" + await global_logger.shutdown.call_one() + + +class TestGetOrCreateMetricLogger: + """Test the integration function.""" + + @pytest.mark.timeout(3) + @pytest.mark.asyncio + async def test_get_or_create_functionality(self): + """Test get_or_create_metric_logger basic functionality.""" + result = await get_or_create_metric_logger() + + # Should return a GlobalLoggingActor mesh + assert result is not None + + # Should be able to call basic methods + count = await result.get_fetcher_count.call_one() + assert count >= 0 diff --git a/tests/unit_tests/observability/test_metrics.py b/tests/unit_tests/observability/test_metrics.py index 701bda2dc..4bb530eb3 100644 --- a/tests/unit_tests/observability/test_metrics.py +++ b/tests/unit_tests/observability/test_metrics.py @@ -16,6 +16,7 @@ BackendRole, ConsoleBackend, get_logger_backend_class, + LoggingMode, MaxAccumulator, MeanAccumulator, Metric, @@ -80,18 +81,17 @@ def test_new_enums_and_constants(self): assert isinstance(BackendRole.LOCAL, BackendRole) assert isinstance(BackendRole.GLOBAL, BackendRole) - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_backend_role_usage(self, mock_actor_name): + async def test_backend_role_usage(self): """Test that BackendRole constants are actually used instead of string literals.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - # Test ConsoleBackend console_backend = ConsoleBackend({}) await console_backend.init(role=BackendRole.LOCAL) # Test WandbBackend role validation without WandB initialization - wandb_backend = WandbBackend({"project": "test"}) + wandb_backend = WandbBackend( + {"project": "test", "logging_mode": "global_reduce"} + ) # Mock all the WandB init methods to focus only on role validation with patch.object(wandb_backend, "_init_global"), patch.object( @@ -295,41 +295,36 @@ def test_record_metric_enabled_explicit(self, mock_collector_class, mock_rank): 
mock_collector_class.assert_called_once() mock_collector.push.assert_called_once() - @patch("forge.observability.metrics.get_actor_name_with_rank") - def test_wandb_backend_creation(self, mock_actor_name): + def test_wandb_backend_creation(self): """Test WandbBackend creation and basic setup without WandB dependency.""" - mock_actor_name.return_value = "TestActor_abcd_r0" config = { "project": "test_project", "group": "test_group", - "reduce_across_ranks": True, + "logging_mode": "global_reduce", } backend = WandbBackend(config) assert backend.project == "test_project" assert backend.group == "test_group" - assert backend.reduce_across_ranks is True - assert backend.share_run_id is False # default + assert backend.logging_mode == LoggingMode.GLOBAL_REDUCE + assert backend.per_rank_share_run is False # default # Test metadata method metadata = backend.get_metadata_for_secondary_ranks() assert metadata == {} # Should be empty when no run - @patch("forge.observability.metrics.get_actor_name_with_rank") @pytest.mark.asyncio - async def test_console_backend(self, mock_actor_name): + async def test_console_backend(self): """Test ConsoleBackend basic operations.""" - mock_actor_name.return_value = "TestActor_abcd_r0" - backend = ConsoleBackend({}) await backend.init(role=BackendRole.LOCAL) - # Test log - should not raise + # Test log_batch - should not raise # Create a test metric test_metric = Metric("test", 1.0, Reduce.MEAN) - await backend.log([test_metric], global_step=1) + await backend.log_batch([test_metric], global_step=1) await backend.finish() # Should not raise
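
Note on the metrics refactor exercised by the tests above: the MetricsAggregator plumbing is removed, and the checkpoint/resume test utility now collects raw Metric objects and collapses them with keep_last_metric. Below is a minimal, self-contained sketch of that semantics; the Metric and Reduce stand-ins are assumptions added only so the snippet runs on its own and are not the real forge.observability classes, while keep_last_metric mirrors the helper added in this diff.

# Sketch (assumed stand-ins for forge.observability types; keep_last_metric mirrors the test helper above)
from dataclasses import dataclass
from enum import Enum


class Reduce(Enum):
    # Stand-in for forge.observability.Reduce; only MEAN is needed here.
    MEAN = "mean"


@dataclass
class Metric:
    # Stand-in for the observability Metric object the test helper expects.
    key: str
    value: float
    reduction: Reduce


def keep_last_metric(metrics_list: list) -> dict:
    # Same behavior as the helper in the diff: keep the last value seen per key.
    result = {}
    for metric in metrics_list:
        result[metric.key] = metric.value
    return result


metrics = [
    Metric("dataset/samples_seen", 4, Reduce.MEAN),
    Metric("dataset/samples_seen", 8, Reduce.MEAN),
]
assert keep_last_metric(metrics) == {"dataset/samples_seen": 8}  # last value wins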