diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml index e9ddc625a..44e4485e4 100644 --- a/apps/sft/llama3_8b.yaml +++ b/apps/sft/llama3_8b.yaml @@ -46,7 +46,7 @@ parallelism: checkpoint: enable: true folder: ./checkpoint # The folder to save checkpoints to. - initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 @@ -56,12 +56,6 @@ activation_checkpoint: mode: selective selective_ac_option: op -metric_logging: - wandb: - project: sft-training - group: sft_exp_${oc.env:USER} - logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce - # profiling: # enable_profiling: false diff --git a/apps/sft/main.py b/apps/sft/main.py index edda0b49d..aa484608e 100644 --- a/apps/sft/main.py +++ b/apps/sft/main.py @@ -27,7 +27,6 @@ from forge.data.datasets.packed import PackedDataset, TextPacker from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset from forge.data.tokenizer import HuggingFaceModelTokenizer -from forge.observability import get_or_create_metric_logger, record_metric, Reduce from forge.util.config import parse from monarch.actor import current_rank, current_size, endpoint @@ -78,6 +77,7 @@ def __init__(self, config: DictConfig): self.current_step = 0 self.num_training_steps = job_config.training.steps + self.metric_logger = None # TODO: fix this self.gradient_accumulation_steps = 1 # Example value, adjust as needed self._rank = current_rank().rank self._size = math.prod(current_size().values()) @@ -109,22 +109,9 @@ def _init_dist(self): os.environ.update(env) logger.info("env: {}".format(env)) - async def setup_metric_logger(self): - """Initialization happens in the main process. Here we just retrieve it""" - mlogger = await get_or_create_metric_logger() - return mlogger - - def record_batch_metrics(self, data_metrics: list): - """Since the dataloader creates new processes, we dont call `record_metric` in the dataset. 
- Instead, pop the metrics from the batch and record them here.""" - for metric in data_metrics: - record_metric(metric.key, metric.value, metric.reduction) - @endpoint async def setup(self): self.train_dataloader = self.setup_data() - self.mlogger = await self.setup_metric_logger() - # self.train_dataloader = self.setup_data( # self.train_config.train_dataset_config, # self.train_config.train_dataloader_config, @@ -247,9 +234,7 @@ def train_step(self, batch) -> None: # ) as grad_acc: labels = batch.pop("labels") loss = self.forward_backward(batch, labels) - loss = loss.item() - record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) @@ -266,25 +251,14 @@ async def train(self) -> None: while self.current_step < self.num_training_steps: batch = next(dataloader) - - # Pop and record metrics from batch before moving to device - self.record_batch_metrics(batch.pop("metrics", [])) - record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN) - # Move tensors to the appropriate device for k, v in batch.items(): if isinstance(v, torch.Tensor): batch[k] = v.to("cuda") # TODO: hardcoded for now - self.train_step(batch) # self.profiler.step() self.current_step += 1 - # Flush metrics - if self._rank == 0: - logger.debug(f"Flushing metrics at step {self.current_step}") - await self.mlogger.flush.call_one(global_step=self.current_step) - self.checkpointer.save( curr_step=self.current_step, last_step=self.current_step == self.num_training_steps, @@ -296,23 +270,16 @@ async def train(self) -> None: async def cleanup(self) -> None: if self.checkpointer: self.checkpointer.close() - if getattr(self, "mlogger", None): - await self.mlogger.shutdown.call_one() + if self.metric_logger: + self.metric_logger.close() def __repr__(self) -> str: return "Trainer" async def run(cfg: DictConfig) -> None: - - logging.info("Spawning recipe...") + logging.info("Spawing recipe...") process_cfg = cfg.pop("processes") - - # Initialize metric logger in main process - metric_logging_cfg = cfg.get("metric_logging", {}) - mlogger = await get_or_create_metric_logger(process_name="Controller") - await mlogger.init_backends.call_one(metric_logging_cfg) - recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg) logging.info("Created recipe, running setup.") @@ -323,7 +290,6 @@ async def run(cfg: DictConfig) -> None: logging.info("Done training. Clean up") await recipe.cleanup.call() - await recipe.mesh.stop() logging.info("All done!") diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml index f7c4999bb..1c0d5bc8b 100644 --- a/apps/sft/qwen3_8b.yaml +++ b/apps/sft/qwen3_8b.yaml @@ -45,7 +45,7 @@ parallelism: checkpoint: enable: true folder: ./checkpoint # The folder to save checkpoints to. - initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. 
initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 @@ -55,12 +55,6 @@ activation_checkpoint: mode: selective selective_ac_option: op -metric_logging: - wandb: - project: sft-training - group: sft_exp_${oc.env:USER} - logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce - # profiling: # enable_profiling: false diff --git a/src/forge/data/__init__.py b/src/forge/data/__init__.py index 74ba663e0..4347199b9 100644 --- a/src/forge/data/__init__.py +++ b/src/forge/data/__init__.py @@ -4,12 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from .collate import collate_packed -from .metric_transform import DefaultDatasetMetricTransform, MetricTransform from .utils import CROSS_ENTROPY_IGNORE_IDX -__all__ = [ - "collate_packed", - "CROSS_ENTROPY_IGNORE_IDX", - "MetricTransform", - "DefaultDatasetMetricTransform", -] +__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"] diff --git a/src/forge/data/dataset_metrics/__init__.py b/src/forge/data/dataset_metrics/__init__.py new file mode 100644 index 000000000..3a218e282 --- /dev/null +++ b/src/forge/data/dataset_metrics/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + MaxAggHandler, + MeanAggHandler, + MetricState, + MinAggHandler, + StatsAggHandler, + SumAggHandler, +) +from .metric_aggregator import MetricsAggregator +from .metric_transform import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, + MetricTransform, +) + +__all__ = [ + "AggregationType", + "AggregationHandler", + "CategoricalCountAggHandler", + "DefaultTrainingMetricTransform", + "StatsAggHandler", + "MaxAggHandler", + "MeanAggHandler", + "Metric", + "MetricState", + "MetricsAggregator", + "MetricTransform", + "MinAggHandler", + "SumAggHandler", +] diff --git a/src/forge/data/dataset_metrics/metric_agg_handlers.py b/src/forge/data/dataset_metrics/metric_agg_handlers.py new file mode 100644 index 000000000..bb3978a6b --- /dev/null +++ b/src/forge/data/dataset_metrics/metric_agg_handlers.py @@ -0,0 +1,466 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from abc import ABC, abstractmethod +from collections import Counter, deque +from dataclasses import dataclass, field +from typing import Any + +import torch + +from .metric_transform import AggregationType, Metric + +logger = logging.getLogger(__name__) + + +@dataclass +class MetricState: + """Mutable state object representing the state of a (source, metric_name) on a single rank. + + Attributes: + source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation. + metric_name (str): Name of the metric. + value (float): Current aggregated value, whose meaning depends on the aggregation type + (e.g., running sum, current max). + agg_type (AggregationType): Aggregation type. + metadata (dict[str, Any]): Additional state like count, list of values, etc. 
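+
+    Example (values are illustrative):
+        >>> state = MetricState(
+        ...     source="alpaca",
+        ...     metric_name="tokens_seen",
+        ...     value=1200.0,
+        ...     agg_type=AggregationType.SUM,
+        ... )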
+ """ + + source: str + metric_name: str + value: float + agg_type: AggregationType + metadata: dict[str, Any] = field(default_factory=dict) + + +class AggregationHandler(ABC): + """Base class for handling metric aggregation. + + Each handler implements a specific aggregation strategy (SUM, MEAN, STATS, etc.) + and manages the complete lifecycle: initialization, updates, local finalization, + and distributed reduction. Handlers also handle serialization for checkpointing. + + The handler architecture allows pluggable aggregation strategies while maintaining + consistent interfaces for the MetricsAggregator. + """ + + @abstractmethod + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + """Create a new MetricState for a (source, metric_name) pair. + + Args: + source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation. + metric_name (str): Name of the metric. + agg_type (AggregationType): Aggregation type. + + Returns: + MetricState: New MetricState for this (source, metric_name) pair. + """ + pass + + @abstractmethod + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + """Update cumulative MetricState with new metric info. + + Args: + local_agg_metric (MetricState): State of the aggregation for this metric in the local rank. + metric (Metric): Input metric info. + """ + pass + + @abstractmethod + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + """ + Computes the final value from the locally aggregated state. For example, for mean + it would mean to divide the tracked sum by the tracked count. + + This method may expand a single metric into multiple, for instance, + a list of numbers into mean, min, max, and percentiles. + + Args: + local_agg_metric (MetricState): The locally aggregated metric state to finalize. + + Returns: + list[MetricState]: List of finalized metric states. + """ + pass + + @abstractmethod + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + """ + Merge MetricStates from all ranks into final result. For example, for 'sum', it would mean to + sum the values from all ranks. + + Args: + local_agg_metrics (list[MetricState]): list of MetricStates from all ranks for a specific + (source, metric_name) tuple after computing finalize_local_agg. + + Returns: + MetricState: Final result for this (source, metric_name) pair. + """ + pass + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert handler-specific metadata to serializable format. Override this when using + non-serializable types like deque or Counter. For example, convert deque to list, Counter to dict. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Returns: + dict[str, Any]: Serializable metadata. + """ + return metadata.copy() + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Restore handler-specific metadata from serialized format. Override this to reverse the + serialize_metadata transformation. For example, convert list back to deque, dict back to Counter. + + Args: + metadata (dict[str, Any]): AggHandler-specific metadata. + + Returns: + dict[str, Any]: Deserialized metadata. + """ + return metadata.copy() + + +class SumAggHandler(AggregationHandler): + """AggHandler for SUM aggregation. 
Initializes with 0.0 and accumulates metric values.""" + + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + source=source, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"SumAggHandler expects numeric values, got {type(metric.value)}" + ) + local_agg_metric.value += metric.value + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + if not local_agg_metrics: + raise ValueError("Cannot aggregate empty list of metrics") + + total = sum(metric.value for metric in local_agg_metrics) + return MetricState( + source=local_agg_metrics[0].source, + metric_name=local_agg_metrics[0].metric_name, + value=total, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy(), + ) + + +class MaxAggHandler(AggregationHandler): + """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" + + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + source=source, + metric_name=metric_name, + value=float("-inf"), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MaxAggHandler expects numeric values, got {type(metric.value)}" + ) + local_agg_metric.value = max(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + max_value = max(r.value for r in local_agg_metrics) + return MetricState( + source=local_agg_metrics[0].source, + metric_name=local_agg_metrics[0].metric_name, + value=max_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy(), + ) + + +class MinAggHandler(AggregationHandler): + """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" + + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + source=source, + metric_name=metric_name, + value=float("inf"), + agg_type=agg_type, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + if not isinstance(metric.value, (int, float)): + raise ValueError( + f"MinAggHandler expects numeric values, got {type(metric.value)}" + ) + local_agg_metric.value = min(local_agg_metric.value, metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + min_value = min(r.value for r in local_agg_metrics) + return MetricState( + source=local_agg_metrics[0].source, + metric_name=local_agg_metrics[0].metric_name, + value=min_value, + agg_type=local_agg_metrics[0].agg_type, + metadata=local_agg_metrics[0].metadata.copy(), + ) + + +class MeanAggHandler(AggregationHandler): + """AggHandler for MEAN aggregation. 
Maintains running sum and count to compute average.""" + + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + source=source, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"sum": 0.0, "count": 0}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["sum"] += metric.value + local_agg_metric.metadata["count"] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + count = local_agg_metric.metadata["count"] + local_agg_metric.value = ( + local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 + ) + return [local_agg_metric] + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) + total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) + + return MetricState( + source=local_agg_metrics[0].source, + metric_name=local_agg_metrics[0].metric_name, + value=total_sum / total_count if total_count > 0 else 0.0, + agg_type=local_agg_metrics[0].agg_type, + metadata={"sum": total_sum, "count": total_count}, + ) + + +class StatsAggHandler(AggregationHandler): + """AggHandler for STATS aggregation. Maintains a sliding window of values + and expands into multiple statistical metrics (mean, min, max, percentiles, std). + + Note: Percentiles and standard deviation are approximated in distributed settings by averaging local + percentiles and standard deviations across ranks. This is mathematically imprecise but provides a + reasonable approximation for monitoring purposes. + + Args: + window_size (int): Maximum number of recent values to retain for statistics. + + Raises: + ValueError: If window_size is not positive. + """ + + def __init__(self, window_size: int = 1000): + if window_size <= 0: + raise ValueError(f"window_size must be positive, got {window_size}") + self.window_size = window_size + + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + source=source, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"values": deque(maxlen=self.window_size)}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["values"].append(metric.value) + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + values = list(local_agg_metric.metadata["values"]) + if not values: + return [] + + values_tensor = torch.tensor(values, dtype=torch.float64) + n = len(values_tensor) + + # Compute stats from the tensor + sum_val = torch.sum(values_tensor).item() + mean_val = sum_val / n + min_val = torch.min(values_tensor).item() + max_val = torch.max(values_tensor).item() + + # Compute percentiles + percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) + p05_val, p50_val, p95_val = torch.quantile( + values_tensor, percentile_definitions + ).tolist() + + # Return multiple MetricStates with proper agg_types for distributed reduction + # NOTE: Percentiles use MEAN aggregation which approximates global percentiles + # by averaging local percentiles. 
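+        # For example, if two ranks report local p50 values of 4.0 and 6.0,
+        # the reported global p50 is (4.0 + 6.0) / 2 = 5.0, which is an
+        # approximation rather than the true percentile of the pooled values.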
+ metrics = [ + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_mean", + value=mean_val, + agg_type=AggregationType.MEAN, + metadata={"sum": sum_val, "count": n}, + ), + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_min", + value=min_val, + agg_type=AggregationType.MIN, + metadata={}, + ), + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_max", + value=max_val, + agg_type=AggregationType.MAX, + metadata={}, + ), + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_p05", + value=p05_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p05_val, "count": 1}, + ), + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_p50", + value=p50_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p50_val, "count": 1}, + ), + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_p95", + value=p95_val, + agg_type=AggregationType.MEAN, + metadata={"sum": p95_val, "count": 1}, + ), + ] + + # Standard deviation is only well-defined for n > 1 + if n > 1: + std_val = torch.std(values_tensor).item() + metrics.append( + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_stat_std", + value=std_val, + agg_type=AggregationType.MEAN, + metadata={"sum": std_val, "count": 1}, + ) + ) + return metrics + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.STATS were converted to other " + "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert deque to list for serialization.""" + serialized = metadata.copy() + if "values" in serialized: + serialized["values"] = list(serialized["values"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert list back to deque.""" + deserialized = metadata.copy() + if "values" in deserialized: + deserialized["values"] = deque( + deserialized["values"], maxlen=self.window_size + ) + return deserialized + + +class CategoricalCountAggHandler(AggregationHandler): + """AggHandler for CATEGORICAL_COUNT aggregation. 
Counts occurrences of categorical values + and expands into individual count metrics for each category.""" + + def initialize_metric_state( + self, source: str, metric_name: str, agg_type: AggregationType + ) -> MetricState: + return MetricState( + source=source, + metric_name=metric_name, + value=0.0, + agg_type=agg_type, + metadata={"counts": Counter()}, + ) + + def update(self, local_agg_metric: MetricState, metric: Metric) -> None: + local_agg_metric.metadata["counts"][metric.value] += 1 + + def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: + # Expand categorical counts into individual metrics + results = [] + for category, count in local_agg_metric.metadata["counts"].items(): + results.append( + MetricState( + source=local_agg_metric.source, + metric_name=f"{local_agg_metric.metric_name}_count_{category}", + value=count, + agg_type=AggregationType.SUM, + ) + ) + return results + + def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: + raise NotImplementedError( + "Metrics with AggregationType.CATEGORICAL_COUNT were converted to other " + "AggregationType.SUM for distributed reduction. finalize_dist_agg should not be called." + ) + + def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert Counter to dict for serialization.""" + serialized = metadata.copy() + if "counts" in serialized: + serialized["counts"] = dict(serialized["counts"]) + return serialized + + def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: + """Convert dict back to Counter.""" + deserialized = metadata.copy() + if "counts" in deserialized: + deserialized["counts"] = Counter(deserialized["counts"]) + return deserialized diff --git a/src/forge/data/dataset_metrics/metric_aggregator.py b/src/forge/data/dataset_metrics/metric_aggregator.py new file mode 100644 index 000000000..40d8075ce --- /dev/null +++ b/src/forge/data/dataset_metrics/metric_aggregator.py @@ -0,0 +1,344 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import ast +import logging +from collections import defaultdict +from typing import Any, Union + +import torch.distributed as dist + +from .metric_agg_handlers import ( + AggregationHandler, + CategoricalCountAggHandler, + MaxAggHandler, + MeanAggHandler, + MetricState, + MinAggHandler, + StatsAggHandler, + SumAggHandler, +) +from .metric_transform import AggregationType, Metric + +logger = logging.getLogger(__name__) + + +class MetricsAggregator: + """Aggregates metrics across datasets and distributed ranks using pluggable handlers. + + This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.) + has a corresponding AggregationHandler. It maintains a single state object for each + (source, metric_name) pair. + + Internal State Visualization: + { + ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...), + ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}), + ("slim_orca", "seq_len"): MetricState(agg_type=STATS, metadata={'values': deque([...])}), + } + + When preparing metrics for logging, the aggregator follows a two-phase process: + 1. Local Aggregation: Each rank aggregates its metrics independently + 2. 
Distributed Reduction: If in distributed mode, results are combined across ranks + + The aggregator's state is checkpointable, allowing training resumption. + + Args: + dist_window_size (int): Window size for StatsAggHandler tracking. + + Example: + >>> from forge.data.metrics import MetricsAggregator, Metric, AggregationType + >>> + >>> aggregator = MetricsAggregator() + >>> + >>> # Sample metrics from different batches + >>> batch1_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> batch2_metrics = [ + ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), + ... ] + >>> + >>> # Update with metrics + >>> aggregator.update(batch1_metrics) + >>> aggregator.update(batch2_metrics) + >>> + >>> # Get final results + >>> results = aggregator.get_metrics_for_logging(prefix="train") + >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} + + Raises: + ValueError: If dist_window_size is not positive. + """ + + def __init__(self, dist_window_size: int = 1000): + if dist_window_size <= 0: + raise ValueError( + f"dist_window_size must be positive, got {dist_window_size}" + ) + + # Storage: {(source, metric_name): MetricState} - O(unique metrics) not O(samples) + self._metric_states: dict[tuple[str, str], MetricState] = {} + self._dist_window_size = dist_window_size + + # Track aggregation types for validation - prevents same metric name with different agg types + self._metric_agg_types: dict[tuple[str, str], AggregationType] = {} + + # Create handler registry - all handlers initialized upfront + self._handlers: dict[AggregationType, AggregationHandler] = { + AggregationType.SUM: SumAggHandler(), + AggregationType.MAX: MaxAggHandler(), + AggregationType.MIN: MinAggHandler(), + AggregationType.MEAN: MeanAggHandler(), + AggregationType.STATS: StatsAggHandler(dist_window_size), + AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), + } + + def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None: + """Validate that metric name uses consistent aggregation type.""" + metric_key = (metric.source, metric.metric_name) + metric_name = metric.metric_name + + if metric_key in self._metric_agg_types: + existing_agg_type = self._metric_agg_types[metric_key] + if existing_agg_type != metric.agg_type: + raise ValueError( + f"Metric '{metric_name}' in dataset '{metric.source}' " + f"is already registered with aggregation type {existing_agg_type.value}, " + f"but a handler or user code tried to use it with type {metric.agg_type.value}. " + f"Use different metric names for different aggregation types." + ) + else: + # Track this metric's aggregation type + self._metric_agg_types[metric_key] = metric.agg_type + + def register_handler( + self, agg_type: AggregationType, handler: AggregationHandler + ) -> None: + """Register custom aggregation handler for specified type. + + Args: + agg_type (AggregationType): The aggregation type to handle + handler (AggregationHandler): Handler instance implementing the AggregationHandler interface + """ + # Warn if replacing a handler that's already in use + if agg_type in self._handlers and any( + state.agg_type == agg_type for state in self._metric_states.values() + ): + logger.warning( + f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. " + f"This may affect existing metric behavior." 
+ ) + + self._handlers[agg_type] = handler + + def update(self, metrics: list[Metric]) -> None: + """Update (source, metric_name) metric state with new values. + + Args: + metrics (list[Metric]): List of metrics to update the state with + + Raises: + ValueError: If no handler is registered for a metric's aggregation type, + or if metric name conflicts with existing aggregation type. + """ + for metric in metrics: + # Same metric name must use same aggregation type + self._validate_metric_consistency(metric) + + metric_key = (metric.source, metric.metric_name) + handler = self._handlers.get(metric.agg_type) + + if handler is None: + raise ValueError( + f"No handler registered for aggregation type: {metric.agg_type}" + ) + + if metric_key not in self._metric_states: + self._metric_states[metric_key] = handler.initialize_metric_state( + metric.source, metric.metric_name, metric.agg_type + ) + + local_agg_metric = self._metric_states[metric_key] + handler.update(local_agg_metric, metric) # Mutates local_agg_metric + + def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: + """Get final metrics for logging in standard format. + + Args: + prefix (str): Prefix for metric names in the returned dictionary + + Returns: + dict[str, float]: Dictionary with keys like "{prefix}_{source}/{metric_name}" + and float values. For example, with `prefix="train"`, `source="alpaca"`, + `metric_name="loss"`, the key would be `train_alpaca/loss`. + """ + final_results = self._compute_unified_metrics() + + return { + f"{prefix}_{result.source}/{result.metric_name}": result.value + for result in final_results + } + + def _compute_unified_metrics(self) -> list[MetricState]: + """ + Compute metrics handling both local and distributed cases uniformly. + + Returns: + list[MetricState]: Final results ready for logging + """ + # Step 1: Get local results from all handlers (may expand stats/categoricals) + prepared_results = [] + for local_agg_metric in self._metric_states.values(): + handler = self._handlers[local_agg_metric.agg_type] + generated_metrics = handler.finalize_local_agg(local_agg_metric) + + # Validate each newly generated metric state immediately + for gen_metric in generated_metrics: + self._validate_metric_consistency(gen_metric) + + prepared_results.extend(generated_metrics) + + # Step 2: Apply distributed reduction if needed + if dist.is_initialized() and dist.get_world_size() > 1: + prepared_results = self._finalize_dist_agg(prepared_results) + + return prepared_results + + def _finalize_dist_agg( + self, local_agg_metrics: list[MetricState] + ) -> list[MetricState]: + """Apply distributed reduction to local results. 
+ + Args: + local_agg_metrics (list[MetricState]): (source, metric_name) metric pairs from this rank + + Returns: + list[MetricState]: Reduced results combining all ranks + """ + world_size = dist.get_world_size() + + # Gather all results from all ranks + all_results = [None] * world_size + dist.all_gather_object(all_results, local_agg_metrics) + + # Group by (source, metric_name) for reduction + grouped = defaultdict(list) + for rank_results in all_results: + if rank_results: # Handle ranks with no metrics + for result in rank_results: + result_key = (result.source, result.metric_name) + grouped[result_key].append(result) + + # Apply handler-specific distributed reduction + reduced_results = [] + for result_key, results_list in grouped.items(): + if not results_list: + continue # Skip empty groups + + # All results for a key should have same agg_type + agg_type = results_list[0].agg_type + handler = self._handlers[agg_type] + reduced_result = handler.finalize_dist_agg(results_list) + reduced_results.append(reduced_result) + + return reduced_results + + def state_dict(self) -> dict[str, Any]: + """Serialize aggregator state for checkpointing. + + Returns: + dict[str, Any]: Serializable dictionary containing all aggregator state + """ + serializable_state = {} + required_agg_types = set() # Track aggregation types used in saved states + + for metric_key, local_agg_metric in self._metric_states.items(): + # Get handler for this result's aggregation type + handler = self._handlers[local_agg_metric.agg_type] + required_agg_types.add(local_agg_metric.agg_type) + + # Convert MetricState to serializable dict + result_dict = { + "source": local_agg_metric.source, + "metric_name": local_agg_metric.metric_name, + "value": local_agg_metric.value, + "agg_type": local_agg_metric.agg_type, + "metadata": handler.serialize_metadata(local_agg_metric.metadata), + } + + # Convert tuple key to string for JSON compatibility + serializable_state[str(metric_key)] = result_dict + + return { + "state": serializable_state, + "dist_window_size": self._dist_window_size, + "required_agg_types": list( + required_agg_types + ), # Save which handlers are needed + # Save which aggregation types are used for each metric + "metric_agg_types": { + str(k): v.value for k, v in self._metric_agg_types.items() + }, + } + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """Load aggregator state from checkpoint. + + Args: + state_dict (dict[str, Any]): Dictionary containing serialized aggregator state + + Raises: + ValueError: If required handlers are missing after checkpoint restore + """ + self._dist_window_size = state_dict.get("dist_window_size", 1000) + + # Sanity check: Ensure all required handlers are available + required_agg_types = state_dict.get("required_agg_types", []) + missing_handlers = [] + for agg_type in required_agg_types: + if agg_type not in self._handlers: + missing_handlers.append(agg_type) + + if missing_handlers: + raise ValueError( + f"Missing handlers for aggregation types: {missing_handlers}. " + f"Custom handlers must be re-registered before checkpoint restore." 
+ ) + + deserialized_state = {} + for key_str, result_dict in state_dict["state"].items(): + # Convert string keys back to tuples + metric_key = ast.literal_eval(key_str) + + # Get handler for this aggregation type + agg_type = result_dict["agg_type"] + handler = self._handlers[agg_type] + + # Restore metadata using handler-specific deserialization + metadata = handler.deserialize_metadata(result_dict["metadata"]) + + # Create MetricState from dict + local_agg_metric = MetricState( + source=result_dict["source"], + metric_name=result_dict["metric_name"], + value=result_dict["value"], + agg_type=result_dict["agg_type"], + metadata=metadata, + ) + + deserialized_state[metric_key] = local_agg_metric + + self._metric_states = deserialized_state + + # Restore validation state + self._metric_agg_types = {} + for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items(): + key = ast.literal_eval(key_str) + self._metric_agg_types[key] = AggregationType(agg_type_str) diff --git a/src/forge/data/dataset_metrics/metric_transform.py b/src/forge/data/dataset_metrics/metric_transform.py new file mode 100644 index 000000000..a9af39ade --- /dev/null +++ b/src/forge/data/dataset_metrics/metric_transform.py @@ -0,0 +1,148 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, Union + + +@dataclass(frozen=True) +class Metric: + source: str + metric_name: str + value: Union[int, float, str] + agg_type: "AggregationType" + + +class AggregationType(Enum): + """Defines how a metric's value should be aggregated by the MetricsAggregator. + + Each type corresponds to a specific AggregationHandler that implements the logic + for initialization, updates, and distributed reduction. + """ + + SUM = "sum" + MEAN = "mean" + STATS = "distribution" + CATEGORICAL_COUNT = "categorical_count" + MAX = "max" + MIN = "min" + + +class MetricTransform(ABC): + """Applied to each dataset sample to generate per-sample metrics for training tracking. + + Creates Metric objects that are later aggregated by MetricsAggregator. This separation + of concerns ensures metrics are correctly aggregated even with multiple dataloader + workers and in distributed settings. + + The transform must be configured with a source via set_source() before use. + Each call to __call__ adds metrics to the sample's "metrics" key. + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_source("alpaca") + >>> sample = {"tokens": [1, 2, 3]} + >>> result = transform(sample) + >>> # result["metrics"] contains list of Metric objects + """ + + def set_source(self, source: str) -> None: + """Called by the dataset to set the namespace for metrics. + + This is used to differentiate metrics from multiple datasets, for example, + "alpaca/tokens_seen" vs. "slim_orca/tokens_seen". + + Args: + source (str): Name of the dataset, used for logging and disambiguation. + """ + self.source = source + + @abstractmethod + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + """Generate metrics for a single sample. + + Args: + sample (dict[str, Any]): The sample dictionary to generate metrics from + + Returns: + list[Metric]: List of metrics generated for this sample + + Raises: + NotImplementedError: If subclass does not implement this method. 
+ """ + raise NotImplementedError("Subclasses must implement _generate_metrics method") + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + if not hasattr(self, "source"): + raise RuntimeError( + "'transform.set_source' must be called before using the transform." + ) + + # Generate metrics for this sample + metrics = self._generate_metrics(sample) + + # Add to existing metrics list or create new one + if "metrics" not in sample: + sample["metrics"] = [] + sample["metrics"].extend(metrics) + return sample + + +class DefaultTrainingMetricTransform(MetricTransform): + """Generates common training metrics: samples seen, tokens seen, and sequence length. + + This transform detects the token key in a sample, checking for "tokens" + first and then falling back to "input_ids". + + For details on the base class behavior, see MetricTransform. + + Tracked metrics: + - samples_seen: Cumulative count of samples processed (SUM aggregation) + - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) + - seq_len: Distribution stats of sequence lengths (STATS aggregation) + + Example: + >>> transform = DefaultTrainingMetricTransform() + >>> transform.set_source("alpaca") + >>> + >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens + >>> metrics = transform._generate_metrics(sample) + >>> # This generates the following Metric objects: + >>> # [ + >>> # Metric(source="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), + >>> # Metric(source="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), + >>> # Metric(source="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.STATS) + >>> # ] + """ + + def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: + # Determine token key + token_key = "tokens" if "tokens" in sample else "input_ids" + token_len = len(sample.get(token_key, [])) + + # Create metrics for this sample + return [ + Metric( + source=self.source, + metric_name="samples_seen", + value=1, + agg_type=AggregationType.SUM, + ), + Metric( + source=self.source, + metric_name="tokens_seen", + value=token_len, + agg_type=AggregationType.SUM, + ), + Metric( + source=self.source, + metric_name="seq_len", + value=token_len, + agg_type=AggregationType.STATS, + ), + ] diff --git a/src/forge/data/dataset_metrics/readme.md b/src/forge/data/dataset_metrics/readme.md new file mode 100644 index 000000000..76ec424e3 --- /dev/null +++ b/src/forge/data/dataset_metrics/readme.md @@ -0,0 +1,176 @@ +# forge Metrics Module + +## Overview + +The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics. 
+ +## Architecture Overview + +``` +┌────────────────────────────────────────────────────┐ +│ Training Loop │ +└─────────────────────┬──────────────────────────────┘ + │ +┌─────────────────────▼──────────────────────────────┐ +│ MetricTransform │ +│ • Applied to each sample │ +│ • Generates per-sample metrics │ +│ • Examples: tokens_seen, seq_len, samples_seen │ +└─────────────────────┬──────────────────────────────┘ + │ list[Metric] +┌─────────────────────▼──────────────────────────────┐ +│ MetricsAggregator │ +│ • Aggregates metrics across samples and ranks │ +│ • Uses pluggable AggregationHandlers │ +│ • Handles distributed reduction │ +└─────────────────────┬──────────────────────────────┘ + │ {prefix}_{source}/{metric_name} # prefix is "train", "val", etc. +┌─────────────────────▼──────────────────────────────┐ +│ Logging System │ +│ • W&B, TensorBoard, etc. │ +│ • Gets formatted metrics ready for logging │ +└────────────────────────────────────────────────────┘ +``` + +## File Structure + +- **`metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes +- **`metric_agg_handlers.py`**: Aggregation strategy implementations +- **`metric_aggregator.py`**: Main aggregator orchestrating the handlers + +## Customizing metrics + +- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics +- **Handler registration**: Register custom handlers for specialized aggregation needs + +####### +## TODO +## Move this from here to website docs +####### + +## Core Components + +### 1. MetricTransform +Generates per-sample metrics during data processing. + +**Key Features:** +- Applied to each sample in the dataset +- Creates `Metric` objects with dataset name, metric name, value, and aggregation type +- Handles dataset namespacing for multi-dataset scenarios + +**Example Usage:** +```python +from forge.data.metrics import DefaultTrainingMetricTransform, AggregationType + +transform = DefaultTrainingMetricTransform() +transform.set_source("alpaca") + +# Applied to each sample +sample = {"tokens": [1, 2, 3, 4, 5]} +sample = transform(sample) +# sample["metrics"] now contains: +# [ +# Metric(source="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), +# Metric(source="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), +# Metric(source="alpaca", name="seq_len", value=5, agg_type=AggregationType.STATS) +# ] +``` + +### 2. MetricsAggregator +Efficiently aggregates metrics across samples and distributed ranks. + +**Key Features:** +- Handler-based strategy pattern for different aggregation types +- Distributed-aware with automatic rank reduction +- Checkpointable state for training resumption +- Keep track of (metric, dataset) pairs + +**Aggregation Types (at the time of writing):** +- `SUM`: Cumulative totals (e.g., total tokens processed) +- `MEAN`: Running averages (e.g., average loss) +- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen) +- `STATS`: Statistical summaries (mean, min, max, percentiles) +- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. 
num of samples from a given category) + +**Example Usage:** +```python +from forge.data.metrics import MetricsAggregator, Metric, AggregationType + +# Create aggregator +aggregator = MetricsAggregator() + +# Sample metrics from different batches +batch1_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +batch2_metrics = [ + Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), + Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), +] + +# Update with metrics +aggregator.update(batch1_metrics) +aggregator.update(batch2_metrics) + +# Get final results +results = aggregator.get_metrics_for_logging(prefix="train") +# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} +``` + +### 3. AggregationHandlers +Pluggable strategies for different aggregation patterns. + +``` +AggregationHandler (ABC) +├── SumAggHandler # value += metric.value +├── MeanAggHandler # tracks sum and count +├── MaxAggHandler # value = max(value, metric.value) +├── MinAggHandler # value = min(value, metric.value) +├── StatsAggHandler # maintains value window + stats +└── CategoricalCountAggHandler # Counter for categories +``` + +**Custom Handler Example:** +```python +class CustomAggHandler(AggregationHandler): + def initialize_metric_state(self, source, metric_name, agg_type): + return MetricState( + source=source, + metric_name=metric_name, + value=, # should change + agg_type=agg_type, + metadata={} # may need to change + ) + + def update(self, local_agg_metric, metric): + ... + + def finalize_local_agg(self, local_agg_metric): + ... + + def finalize_dist_agg(self, local_agg_metrics): + ... + +# Register with aggregator +aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler()) +``` + +## Distributed Training Support + +The metrics system automatically handles distributed environments: + +1. **Local Aggregation**: Each rank aggregates its own metrics +2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object` +3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.) 
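+
+For illustration, the per-rank sketch below assumes the default process group
+has already been initialized (e.g. by the training launcher); the numbers are
+made up:
+
+```python
+import torch.distributed as dist
+
+from forge.data.dataset_metrics import AggregationType, Metric, MetricsAggregator
+
+rank = dist.get_rank()  # assumes init_process_group() was already called
+
+# Each rank keeps its own aggregator and only sees its local metrics
+aggregator = MetricsAggregator()
+aggregator.update(
+    [Metric("alpaca", "tokens_seen", 100 * (rank + 1), AggregationType.SUM)]
+)
+
+# get_metrics_for_logging() all-gathers the locally finalized states and
+# reduces them per aggregation type (SUM here), so every rank sees the total
+metrics = aggregator.get_metrics_for_logging(prefix="train")
+# With 2 ranks: {"train_alpaca/tokens_seen": 300.0}
+```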
+ +**Distributed Flow:** +``` +Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] +Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] + ↓ + AllGather + Reduce + ↓ + Final Results [(ds1, metric1), (ds1, metric2)] +``` diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index d7b36fe68..799dd89b9 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -12,8 +12,12 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from forge.data.metric_transform import DefaultDatasetMetricTransform, MetricTransform -from forge.observability.metrics import Metric, Reduce +from forge.data.dataset_metrics import ( + AggregationType, + DefaultTrainingMetricTransform, + Metric, + MetricTransform, +) from .dataset import DatasetInfo, InfiniteTuneIterableDataset @@ -81,7 +85,7 @@ def __init__( self._weight = weight if weight is not None else 1.0 # Create default transform if not provided - self._metric_transform = metric_transform or DefaultDatasetMetricTransform() + self._metric_transform = metric_transform or DefaultTrainingMetricTransform() # Auto-generate dataset name if not provided if dataset_name is None: @@ -235,16 +239,15 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # Track the number of epochs completed for each dataset. This is # especially useful when interleaving multiple datasets, but # also necessary to track dataset-level metrics. + metric_num_epochs = Metric( + source=self.info.name, + metric_name="num_epochs", + value=self._num_epochs, + agg_type=AggregationType.MAX, + ) if "metrics" not in sample: sample["metrics"] = [] - - sample["metrics"].append( - Metric( - key=f"dataset/{self.info.name}/num_epochs", - value=self._num_epochs, - reduction=Reduce.MAX, - ) - ) + sample["metrics"].append(metric_num_epochs) samples_yielded += 1 yield sample diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index 93a21b85e..d09c158c0 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -16,7 +16,7 @@ from torchdata.stateful_dataloader import Stateful from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.observability.metrics import Metric, Reduce +from forge.data.dataset_metrics import AggregationType, Metric from .dataset import DatasetInfo, InfiniteTuneIterableDataset @@ -605,13 +605,13 @@ def finalize_pack( # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) - pack["metrics"].append( - Metric( - key=f"dataset/{self.dataset_name}/pct_of_tokens_padded", - value=padding_pct, - reduction=Reduce.MEAN, - ) + padding_metric = Metric( + source=self.dataset_name, + metric_name="pct_of_tokens_padded", + value=padding_pct, + agg_type=AggregationType.MEAN, ) + pack["metrics"].append(padding_metric) # Concatenate tensor lists and handle other keys result = { @@ -635,7 +635,7 @@ def finalize_pack( if pack["input_pos"] else torch.empty(0, dtype=torch.long) ), - "metrics": pack["metrics"], + # "metrics": pack["metrics"], } # Handle arbitrary keys that aren't tensors - keep as lists @@ -853,13 +853,13 @@ def finalize_pack( # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) - pack["metrics"].append( - Metric( - key=f"dataset/{self.dataset_name}/pct_of_tokens_padded", - value=padding_pct, - 
reduction=Reduce.MEAN, - ) + padding_metric = Metric( + source=self.dataset_name, + metric_name="pct_of_tokens_padded", + value=padding_pct, + agg_type=AggregationType.MEAN, ) + pack["metrics"].append(padding_metric) # Concatenate tensor lists and handle other keys result = { diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 00278c1e5..3a2574643 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -9,7 +9,7 @@ import torch from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.data.metric_transform import DefaultDatasetMetricTransform +from forge.data.dataset_metrics import DefaultTrainingMetricTransform from forge.data.utils import mask_messages, TuneMessage from .hf_dataset import HfIterableDataset @@ -198,7 +198,7 @@ def sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, output_transform=output_transform, - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, weight=weight, seed=seed, diff --git a/src/forge/data/metric_transform.py b/src/forge/data/metric_transform.py deleted file mode 100644 index cbbc04020..000000000 --- a/src/forge/data/metric_transform.py +++ /dev/null @@ -1,113 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from typing import Any - -from forge.observability.metrics import Metric, Reduce - - -class MetricTransform: - """ - Base class for transforms that collect observability metrics from dataset samples. - - This class provides a foundation for implementing dataset-level metric collection - during data processing pipelines. Subclasses should override the __call__ method - to add specific metrics to each sample that passes through the transform. - - Metrics are collected as `forge.observability.metrics.Metric` objects and made available - in batch["metrics"]. - - Attributes: - source (str, optional): The source name for metrics, typically the dataset name. - This is used as a prefix in metric keys to distinguish metrics from different - data sources. - - Example: - >>> transform = SomeMetricTransform() - >>> transform.set_source("training_data") - >>> processed_sample = transform(sample) - >>> # Metrics are automatically added to sample["metrics"] - """ - - def __init__(self): - self.source = None - - def set_source(self, source: str): - """Set the source name for metrics (typically the dataset name).""" - self.source = source - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - """Transform a sample by adding metrics to it.""" - return sample - - -class DefaultDatasetMetricTransform(MetricTransform): - """ - Collects basic dataset processing metrics during data pipeline execution. - - Metrics collected: - - samples_processed: Total number of samples that have passed through this transform (SUM) - - tokens_processed: Total number of tokens processed across all samples (SUM) - - mean_seq_len: Average sequence length across samples (MEAN) - - max_seq_len: Maximum sequence length observed (MAX) - - min_seq_len: Minimum sequence length observed (MIN) - - Note: Token-related metrics are only collected if the sample contains a 'tokens' field. - Sequence length is measured as the number of tokens in each sample. 
- - Example: - >>> collector = DefaultDatasetMetricTransform() - >>> collector.set_source("training_data") - >>> sample = {"tokens": ["hello", "world"]} - >>> processed_sample = collector(sample) - >>> # Metrics are automatically added to processed_sample["metrics"] - """ - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if "metrics" not in sample: - sample["metrics"] = [] - - source_name = self.source or "unnamed_ds" - - # Add samples_processed metric - sample["metrics"].append( - Metric( - key=f"dataset/{source_name}/samples_processed", - value=1, - reduction=Reduce.SUM, - ) - ) - - # Add token-based metrics if tokens are present - if "tokens" in sample: - token_count = len(sample.get("tokens", [])) - - sample["metrics"].extend( - [ - Metric( - key=f"dataset/{source_name}/tokens_processed", - value=token_count, - reduction=Reduce.SUM, - ), - Metric( - key=f"dataset/{source_name}/mean_seq_len", - value=token_count, - reduction=Reduce.MEAN, - ), - Metric( - key=f"dataset/{source_name}/max_seq_len", - value=token_count, - reduction=Reduce.MAX, - ), - Metric( - key=f"dataset/{source_name}/min_seq_len", - value=token_count, - reduction=Reduce.MIN, - ), - ] - ) - - return sample diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index d01617ba3..0ba2b8c73 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -82,7 +82,6 @@ async def get_or_create_metric_logger( # Shutdown await mlogger.shutdown.call_one() """ - # Get or create the singleton global logger global _global_logger diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index 35e12ab3f..c5ca7cf1f 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -544,7 +544,7 @@ def push(self, metric: Metric) -> None: " Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." " To disable this warning, please call in your main file:\n" - "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" + "`mlogger = await get_or_create_metric_logger()`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "or set env variable `FORGE_DISABLE_METRICS=True`" ), diff --git a/tests/unit_tests/data/__init__.py b/tests/unit_tests/data/__init__.py new file mode 100644 index 000000000..2e41cd717 --- /dev/null +++ b/tests/unit_tests/data/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/data/test_metrics_aggregator.py b/tests/unit_tests/data/test_metrics_aggregator.py new file mode 100644 index 000000000..5b847c92f --- /dev/null +++ b/tests/unit_tests/data/test_metrics_aggregator.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for MetricsAggregator functionality. 
+ +This module tests the metrics collection and aggregation system including: +- All aggregation types (SUM, MEAN, MAX, MIN, STATS, CATEGORICAL_COUNT) +- State management and checkpointing +- Multi-dataset metric namespacing +- Distributed metrics aggregation +- Metric consistency validation + +Uses synthetic metrics to verify correct aggregation behavior across scenarios. +""" + +import logging + +import pytest +import torch.distributed as dist + +from forge.data.dataset_metrics import AggregationType, Metric, MetricsAggregator +from torch.testing._internal.common_fsdp import FSDPTest + +from tests.test_utils import gpu_test + + +class TestMetricsAggregator: + """Tests for MetricsAggregator core functionality and edge cases.""" + + @pytest.mark.parametrize( + "agg_type,test_values,expected", + [ + (AggregationType.SUM, [1, 2, 3, 4], 10), + (AggregationType.MEAN, [10, 20, 30, 40], 25.0), + (AggregationType.MAX, [-5, 10, 3, 15], 15), + (AggregationType.MIN, [5, -2, 8, 1], -2), + ( + AggregationType.CATEGORICAL_COUNT, + ["A", "B", "A", "C", "A"], + {"A": 3, "B": 1, "C": 1}, + ), + ], + ) + def test_aggregation_types(self, agg_type, test_values, expected): + """Tests each AggregationType with representative data to verify correct computation. + + Covers aggregation types: + - SUM: Simple addition across values + - MEAN: Average computation with proper count tracking + - MAX/MIN: Extrema identification + - CATEGORICAL_COUNT: Category frequency counting + """ + aggregator = MetricsAggregator() + + metrics = [ + Metric(source="test", metric_name="metric", value=val, agg_type=agg_type) + for val in test_values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + if agg_type == AggregationType.CATEGORICAL_COUNT: + for category, count in expected.items(): + assert result[f"train_test/metric_count_{category}"] == count + else: + assert result["train_test/metric"] == expected + + def test_stats_metrics(self): + """Tests that STATS aggregation computes statistics (mean, min, max, percentiles).""" + aggregator = MetricsAggregator() + values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + + metrics = [ + Metric("test", "dist_metric", val, AggregationType.STATS) for val in values + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + assert result["train_test/dist_metric_stat_mean"] == 5.5 + assert result["train_test/dist_metric_stat_min"] == 1 + assert result["train_test/dist_metric_stat_max"] == 10 + assert result["train_test/dist_metric_stat_p50"] == 5.5 + + def test_state_management(self): + """Test metrics aggregator state persistence and restoration for checkpointing scenarios.""" + # Create aggregator with mixed metric types to test state saving + aggregator1 = MetricsAggregator() + initial_metrics = [ + Metric("ds1", "counter", 10, AggregationType.SUM), + Metric("ds1", "average", 5.0, AggregationType.MEAN), + Metric("ds2", "categories", "X", AggregationType.CATEGORICAL_COUNT), + ] + aggregator1.update(initial_metrics) + + # Save state + state = aggregator1.state_dict() + + # Create new aggregator and restore state + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state) + + # Both should have identical metrics + metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + metrics2 = aggregator2.get_metrics_for_logging(prefix="train") + assert metrics1 == metrics2 + + # Continue updating both - should remain identical + additional_metrics = [ + Metric("ds1", "counter", 5, AggregationType.SUM), + 
Metric("ds1", "average", 15.0, AggregationType.MEAN), + ] + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + final_metrics1 = aggregator1.get_metrics_for_logging(prefix="train") + final_metrics2 = aggregator2.get_metrics_for_logging(prefix="train") + assert final_metrics1 == final_metrics2 + + # Verify expected values + assert final_metrics1["train_ds1/counter"] == 15 # 10 + 5 + assert final_metrics1["train_ds1/average"] == 10.0 # (5 + 15) / 2 + + def test_multiple_datasets(self): + """Test that metrics from multiple datasets are correctly namespaced.""" + aggregator = MetricsAggregator() + + metrics = [ + Metric("dataset1", "samples", 100, AggregationType.SUM), + Metric("dataset2", "samples", 200, AggregationType.SUM), + Metric("dataset1", "tokens", 1000, AggregationType.SUM), + Metric("dataset2", "tokens", 2000, AggregationType.SUM), + ] + aggregator.update(metrics) + + result = aggregator.get_metrics_for_logging(prefix="train") + + assert result["train_dataset1/samples"] == 100 + assert result["train_dataset2/samples"] == 200 + assert result["train_dataset1/tokens"] == 1000 + assert result["train_dataset2/tokens"] == 2000 + + def test_empty_aggregator(self): + """Test that empty aggregator returns empty metrics.""" + aggregator = MetricsAggregator() + result = aggregator.get_metrics_for_logging(prefix="train") + assert result == {} + + def test_prefix_handling(self): + """Test that prefix is correctly applied to metric keys.""" + aggregator = MetricsAggregator() + metrics = [ + Metric("test_ds", "metric1", 42, AggregationType.SUM), + Metric("test_ds", "metric2", 84, AggregationType.SUM), + ] + aggregator.update(metrics) + + # Test with prefix + result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") + assert result_with_prefix["validation_test_ds/metric1"] == 42 + assert result_with_prefix["validation_test_ds/metric2"] == 84 + + # Test without prefix (uses default "data") + result_no_prefix = aggregator.get_metrics_for_logging() + assert result_no_prefix["data_test_ds/metric1"] == 42 + assert result_no_prefix["data_test_ds/metric2"] == 84 + + def test_metric_consistency_validation(self): + """Test that same metric name must use same aggregation type.""" + aggregator = MetricsAggregator() + + # First metric with SUM aggregation + metrics1 = [Metric("test", "my_metric", 10, AggregationType.SUM)] + aggregator.update(metrics1) + + # Try to use same metric name with different aggregation type - should fail + metrics2 = [Metric("test", "my_metric", 5.0, AggregationType.MEAN)] + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.update(metrics2) + + # Same metric name with same aggregation type should work + metrics3 = [Metric("test", "my_metric", 20, AggregationType.SUM)] + aggregator.update(metrics3) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_test/my_metric"] == 30 # 10 + 20 + + def test_metric_consistency_across_datasets(self): + """Test that same metric name can use different aggregation types across different datasets.""" + aggregator = MetricsAggregator() + + # Same metric name but different datasets - should be allowed + metrics = [ + Metric("dataset1", "metric", 10, AggregationType.SUM), + Metric("dataset2", "metric", 5.0, AggregationType.MEAN), + ] + aggregator.update(metrics) # Should not raise + + result = aggregator.get_metrics_for_logging(prefix="train") + assert result["train_dataset1/metric"] == 10 + 
assert result["train_dataset2/metric"] == 5.0 + + def test_handler_generated_metric_validation(self): + """Test that handler-generated metrics are validated for consistency.""" + aggregator = MetricsAggregator() + + # Create a user-defined metric that will conflict with stats + user_metrics = [ + Metric("test", "dist_metric_stat_mean", 42, AggregationType.SUM) + ] + aggregator.update(user_metrics) + + # Now try to add a stats metric that will generate conflicting stat names + dist_metrics = [Metric("test", "dist_metric", 10, AggregationType.STATS)] + aggregator.update(dist_metrics) + + # This should fail when trying to get metrics for logging because the handler + # will try to create "dist_metric_stat_mean" which conflicts with the user metric + with pytest.raises( + ValueError, match="is already registered with aggregation type sum" + ): + aggregator.get_metrics_for_logging(prefix="train") + + def test_handler_replacement_warning(self, caplog): + """Test that replacing handlers in use generates a warning.""" + aggregator = MetricsAggregator() + + # Add a metric that uses SUM aggregation + metrics = [Metric("test", "sum_metric", 10, AggregationType.SUM)] + aggregator.update(metrics) + + # Replace the SUM handler - should generate warning + from forge.data.dataset_metrics import SumAggHandler + + with caplog.at_level(logging.WARNING): + aggregator.register_handler(AggregationType.SUM, SumAggHandler()) + + # Check that the expected warning was logged + assert len(caplog.records) == 1 + assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message + + +class TestDistributedMetricsAggregator(FSDPTest): + """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" + + @property + def world_size(self) -> int: + return 2 + + @gpu_test(gpu_count=2) + def test_distributed_all_aggregation_types(self): + """ + Test that all aggregation types work correctly in distributed setting. + Each rank contributes different values to ensure proper reduction across ranks. 
+ """ + aggregator = MetricsAggregator() + rank = dist.get_rank() + + # Each rank contributes different values to test cross-rank aggregation + base_value = (rank + 1) * 10 # rank 0: 10, rank 1: 20 + + metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value + 5, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 10, AggregationType.MAX), + Metric("test", "min_metric", base_value // 2, AggregationType.MIN), + ] + + # STATS: Each rank adds 5 values for statistics + # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] + for i in range(5): + metrics.append( + Metric("test", "dist_metric", rank * 10 + i, AggregationType.STATS) + ) + + # CATEGORICAL_COUNT: Different categories per rank to test counting + # rank 0: 3 of cat_A, 2 of cat_B + # rank 1: 1 of cat_A, 4 of cat_C + if rank == 0: + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT + ), + ] + ) + else: + metrics.extend( + [ + Metric( + "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + Metric( + "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT + ), + ] + ) + + # Update aggregator and get results + aggregator.update(metrics) + result = aggregator.get_metrics_for_logging(prefix="train") + + # Verify aggregation results across all ranks + # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 + # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 + # MAX: rank 0 has 100, rank 1 has 200 -> max 200 + # MIN: rank 0 has 5, rank 1 has 10 -> min 5 + assert result["train_test/sum_metric"] == 30 + assert result["train_test/mean_metric"] == 20 + assert result["train_test/max_metric"] == 200 + assert result["train_test/min_metric"] == 5 + + # STATS: Combined values [0,1,2,3,4,10,11,12,13,14] + # Mean should be average of local means: (2 + 12) / 2 = 7 + assert result["train_test/dist_metric_stat_mean"] == 7 + assert result["train_test/dist_metric_stat_min"] == 0 + assert result["train_test/dist_metric_stat_max"] == 14 + + # CATEGORICAL_COUNT: Total counts across ranks + # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 + assert result["train_test/cat_metric_count_cat_A"] == 4 + assert result["train_test/cat_metric_count_cat_B"] == 2 + assert result["train_test/cat_metric_count_cat_C"] == 4 + + @gpu_test(gpu_count=2) + def test_distributed_state_dict_resumption(self): + """ + Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. 
+ Verifies: + - State can be saved after partial updates across ranks + - State can be restored consistently across ranks + - Continued updates after restore produce identical results + - Distributed aggregation works correctly after restoration + """ + rank = dist.get_rank() + + # Phase 1: Create aggregator and add initial metrics + aggregator1 = MetricsAggregator() + + # Each rank contributes different initial values + base_value = rank * 100 # rank 0: 0, rank 1: 100 + + initial_metrics = [ + Metric("test", "sum_metric", base_value, AggregationType.SUM), + Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), + Metric("test", "max_metric", base_value * 2, AggregationType.MAX), + ] + + # Add some STATS values - each rank adds 3 values + for i in range(3): + initial_metrics.append( + Metric("test", "dist_metric", rank * 100 + i, AggregationType.STATS) + ) + + # Add CATEGORICAL_COUNT values + if rank == 0: + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_A", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) + else: + initial_metrics.extend( + [ + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + Metric( + "test", + "cat_metric", + "type_B", + AggregationType.CATEGORICAL_COUNT, + ), + ] + ) + + aggregator1.update(initial_metrics) + + # Save state_dict after initial update + state_dict = aggregator1.state_dict() + + # Phase 2: Create new aggregator and restore from state_dict + aggregator2 = MetricsAggregator() + aggregator2.load_state_dict(state_dict) + + # Verify both aggregators produce identical results after restore + result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") + result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") + assert ( + result1 == result2 + ), f"Rank {rank}: Aggregators differ after state_dict restore" + + # Phase 3: Add more metrics to both aggregators + additional_metrics = [ + Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), + Metric("test", "min_metric", rank * 1000, AggregationType.MIN), + ] + + # Update both aggregators with additional metrics + aggregator1.update(additional_metrics) + aggregator2.update(additional_metrics) + + # Phase 4: Verify final results are identical across both aggregators + final_result1 = aggregator1.get_metrics_for_logging(prefix="final") + final_result2 = aggregator2.get_metrics_for_logging(prefix="final") + assert ( + final_result1 == final_result2 + ), f"Rank {rank}: Final results differ after continued updates" diff --git a/tests/unit_tests/data/test_metrics_transform.py b/tests/unit_tests/data/test_metrics_transform.py new file mode 100644 index 000000000..078b511a8 --- /dev/null +++ b/tests/unit_tests/data/test_metrics_transform.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +""" +Tests cover: +- DefaultTrainingMetricTransform +- Basic metric generation (samples_seen, tokens_seen, seq_len) +- Dataset name validation and requirements +- Proper metric type assignment and aggregation configuration +""" + +import pytest + +from forge.data.dataset_metrics import AggregationType, DefaultTrainingMetricTransform + + +class TestDefaultTrainingMetricTransform: + """Tests for DefaultTrainingMetricTransform functionality.""" + + def test_source_not_set_raises_error(self): + """Test that the transform raises a RuntimeError if used before + `set_source` is called, ensuring that metrics are always + correctly attributed to a dataset.""" + transform = DefaultTrainingMetricTransform() + sample = {"tokens": [1, 2, 3]} + + with pytest.raises(RuntimeError, match="set_source"): + transform(sample) + + def test_basic_metrics_generation(self): + """Test that transform generates expected training metrics for input samples.""" + transform = DefaultTrainingMetricTransform() + # Set dataset name required for metric generation + transform.set_source("test_dataset") + + sample = {"tokens": [1, 2, 3, 4, 5]} + result = transform(sample) + + # Transform should preserve original sample data unchanged + assert result["tokens"] == [1, 2, 3, 4, 5] + + # Should generate exactly 3 metrics: samples_seen, tokens_seen, seq_len + assert "metrics" in result + metrics = result["metrics"] + assert len(metrics) == 3 + + # Verify each metric has correct properties and aggregation type + for metric in metrics: + if metric.metric_name == "samples_seen": + assert metric.source == "test_dataset" + assert metric.value == 1 + assert metric.agg_type == AggregationType.SUM + + elif metric.metric_name == "tokens_seen": + assert metric.source == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.SUM + + elif metric.metric_name == "seq_len": + assert metric.source == "test_dataset" + assert metric.value == 5 + assert metric.agg_type == AggregationType.STATS diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index c493d21eb..c1535c8b8 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -26,8 +26,8 @@ import pytest import torch.distributed as dist +from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator from forge.data.datasets import HfIterableDataset -from forge.data.metric_transform import DefaultDatasetMetricTransform from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader @@ -93,7 +93,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", # dataset_name not provided - should auto-generate seed=SEED, - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, ) @@ -131,7 +131,7 @@ def test_default_dataset_name(self, small_dataset_file): dataset_name="my_dataset", weight=custom_weight, seed=SEED, - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=4, ) @@ -149,16 +149,17 @@ def test_epoch_boundaries_and_checkpointing( the epoch metric is correct, and checkpointing works as expected. 
""" - # 1. Setup Dataloaders for original and resumed runs - def create_loader(): + # 1. Setup Dataloaders and Aggregators for original and resumed runs + def create_loader_and_aggregator(): dataset = dataset_factory(small_dataset_file, shuffle=False) loader = StatefulDataLoader( dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - return loader + aggregator = MetricsAggregator() + return loader, aggregator - loader1 = create_loader() - loader2 = create_loader() + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() # 2. Calculate steps for the test run total_samples = int(SMALL_DATASET_SIZE * num_epochs) @@ -170,9 +171,11 @@ def create_loader(): # 3. Generate checkpoint and resume result = generate_ckpt( loader1, + aggregator1, steps_before_checkpoint=steps_before_checkpoint, steps_after_checkpoint=steps_after_checkpoint, resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # 4. Verify checkpointing and resumption @@ -181,10 +184,9 @@ def create_loader(): assert ( orig_post_ids == resumed_ids ), "Resumed batches should be identical for deterministic run" - assert ( - result["post_checkpoint_metrics"] == result["resumed_metrics"] - ), "Resumed training should produce same metrics as original training" + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" def test_shuffling_behavior(self, dataset_factory, small_dataset_file): """Tests that shuffling changes data order between epochs but preserves the set of samples.""" @@ -251,7 +253,9 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in first_epoch_samples: first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value for metric in first_epoch_metrics if "num_epochs" in metric.key + metric.value + for metric in first_epoch_metrics + if metric.metric_name == "epoch" ] assert all( epoch_value == 0 for epoch_value in epoch_values @@ -264,7 +268,7 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): epoch_values = [ metric.value for metric in second_epoch_metrics - if "num_epochs" in metric.key + if metric.metric_name == "epoch" ] assert all( epoch_value == 1 for epoch_value in epoch_values @@ -310,7 +314,7 @@ def test_distributed_epoch_boundary_checkpointing(self): # Test multiple epoch boundaries for num_epochs in [0.9, 1.0, 2.5]: - def create_loader(): + def create_loader_and_aggregator(): dataset = HfIterableDataset( path="json", data_files=str(medium_dataset_file), @@ -318,7 +322,7 @@ def create_loader(): dataset_name="epoch_test", seed=SEED, shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, ) loader = StatefulDataLoader( @@ -327,10 +331,10 @@ def create_loader(): collate_fn=collate_with_metrics, num_workers=0, ) - return loader + return loader, MetricsAggregator() - loader1 = create_loader() - loader2 = create_loader() + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() # Calculate steps to reach desired epoch boundary samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() @@ -348,9 +352,11 @@ def create_loader(): result = generate_ckpt( loader1, - steps_before_checkpoint=steps_before, - steps_after_checkpoint=steps_after, + aggregator1, + steps_before, + steps_after, resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify deterministic resumption - 
critical for distributed training @@ -369,7 +375,7 @@ def create_loader(): num_epochs - 1e-9 ) # -1e-9 so 1.0 epochs -> 0 assert ( - final_metrics["dataset/epoch_test/num_epochs"] == expected_epoch + final_metrics["train_epoch_test/num_epochs"] == expected_epoch ), f"Epoch count incorrect for {num_epochs} epochs test scenario" finally: diff --git a/tests/unit_tests/datasets/test_interleaved.py b/tests/unit_tests/datasets/test_interleaved.py index afff27ab2..0073b905e 100644 --- a/tests/unit_tests/datasets/test_interleaved.py +++ b/tests/unit_tests/datasets/test_interleaved.py @@ -28,9 +28,9 @@ import torch import torch.distributed as dist -from forge.data.datasets import HfIterableDataset, InterleavedDataset -from forge.data.metric_transform import DefaultDatasetMetricTransform +from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator +from forge.data.datasets import HfIterableDataset, InterleavedDataset from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader @@ -114,7 +114,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -299,47 +299,37 @@ def test_metrics_aggregation( [child_interleaved, ds3], seed=SEED, dataset_name="parent" ) - # Collect metrics manually instead of using old MetricsAggregator - collected_metrics = [] + aggregator = MetricsAggregator() # Process some samples total_samples = 300 for sample in islice(iter(parent_interleaved), total_samples): - if "metrics" in sample: - collected_metrics.extend(sample["metrics"]) - - # Count metrics by dataset name - ds1_samples_processed = sum( - 1 - for m in collected_metrics - if hasattr(m, "key") and "dataset/ds1/samples_processed" in m.key - ) - ds2_samples_processed = sum( - 1 - for m in collected_metrics - if hasattr(m, "key") and "dataset/ds2/samples_processed" in m.key - ) - ds3_samples_processed = sum( - 1 - for m in collected_metrics - if hasattr(m, "key") and "dataset/ds3/samples_processed" in m.key - ) + aggregator.update(sample["metrics"]) + + metrics = aggregator.get_metrics_for_logging(prefix="train") + + # Should have metrics from all three datasets, with flat keys + assert "train_ds1/samples_seen" in metrics + assert "train_ds2/samples_seen" in metrics + assert "train_ds3/samples_seen" in metrics # All datasets should have contributed samples - assert ds1_samples_processed > 0, "ds1 should have contributed samples" - assert ds2_samples_processed > 0, "ds2 should have contributed samples" - assert ds3_samples_processed > 0, "ds3 should have contributed samples" + assert metrics["train_ds1/samples_seen"] > 0 + assert metrics["train_ds2/samples_seen"] > 0 + assert metrics["train_ds3/samples_seen"] > 0 # Total samples should equal what we processed calculated_total_samples = ( - ds1_samples_processed + ds2_samples_processed + ds3_samples_processed + metrics["train_ds1/samples_seen"] + + metrics["train_ds2/samples_seen"] + + metrics["train_ds3/samples_seen"] ) assert calculated_total_samples == total_samples # Test that ratios are approximately correct based on nested weighting - ds1_ratio = ds1_samples_processed / total_samples - ds2_ratio = ds2_samples_processed / total_samples - ds3_ratio = ds3_samples_processed / total_samples + ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples + ds2_ratio = 
metrics["train_ds2/samples_seen"] / total_samples + ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples # Expected ratios based on nested weighting: # Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0 @@ -387,30 +377,32 @@ def create_interleaved(): loader1 = StatefulDataLoader( interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) + aggregator1 = MetricsAggregator() # Resumed run interleaved2 = create_interleaved() loader2 = StatefulDataLoader( interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) + aggregator2 = MetricsAggregator() result = generate_ckpt( loader1, + aggregator1, steps_before_checkpoint=10, steps_after_checkpoint=20, resume_dataloader=loader2, + resume_aggregator=aggregator2, ) - # Verify checkpointing and resumption work correctly - # After loading a checkpoint, training should continue identically orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] assert ( orig_post_ids == resumed_ids ), "Resumed batches should be identical for deterministic run" assert ( - result["post_checkpoint_metrics"] == result["resumed_metrics"] - ), "Resumed training should produce same metrics as original training" + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics should match" # Test sampling log functionality # Check that sampling log contains tuples of (iteration_count, dataset_name) @@ -520,7 +512,7 @@ def create_dataset(): split="train", dataset_name="ds1", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, weight=0.3, ) @@ -530,7 +522,7 @@ def create_dataset(): split="train", dataset_name="ds2", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, weight=0.7, ) @@ -540,7 +532,7 @@ def create_dataset(): split="train", dataset_name="ds3", shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultDatasetMetricTransform(), + metric_transform=DefaultTrainingMetricTransform(), num_shards_per_rank=2, weight=1.0, ) @@ -560,17 +552,19 @@ def create_dataloader(dataset): num_workers=0, # Avoid multiprocessing in distributed tests collate_fn=collate_with_metrics, ) - return loader + return loader, MetricsAggregator() # Run checkpointing test with small number of steps - loader1 = create_dataloader(create_dataset()) - loader2 = create_dataloader(create_dataset()) + loader1, aggregator1 = create_dataloader(create_dataset()) + loader2, aggregator2 = create_dataloader(create_dataset()) result = generate_ckpt( loader1, - steps_before_checkpoint=3, - steps_after_checkpoint=3, + aggregator1, + 3, + 3, # 3 steps before, 3 steps after checkpoint resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify deterministic resumption @@ -583,8 +577,8 @@ def create_dataloader(dataset): f"This indicates sampling state is not properly preserved." 
) assert ( - result["post_checkpoint_metrics"] == result["resumed_metrics"] - ), "Resumed training should produce same metrics as original training" + result["final_metrics"] == result["resumed_metrics"] + ), "Final metrics don't match resumed metrics - aggregator state issue" # Verify sampling ratio is approximately maintained for nested structure all_ids = [] diff --git a/tests/unit_tests/datasets/test_iterable_utils.py b/tests/unit_tests/datasets/test_iterable_utils.py index 0c6d26fe3..cdeced7c7 100644 --- a/tests/unit_tests/datasets/test_iterable_utils.py +++ b/tests/unit_tests/datasets/test_iterable_utils.py @@ -7,91 +7,92 @@ from typing import Any, Optional import torch -from torch.utils.data import DataLoader +from forge.data.dataset_metrics import MetricsAggregator +from torch.utils.data import DataLoader -def collate_with_metrics(batch): - """ - Simple collate function that preserves metrics for validation. - Collects metrics from all samples in the batch and aggregates them. - Uses a simple collation that doesn't enforce same sizes for lists/tokens. - """ - # Collect metrics from all samples - batch_metrics = [] +def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: + """Simple collate that extracts metrics and pads tokens.""" + all_metrics = [] + clean_batch = [] for sample in batch: if "metrics" in sample: - batch_metrics.extend(sample.pop("metrics")) - - # Simple collation that handles variable-length sequences - collated = {} - if batch: - for key in batch[0].keys(): - values = [sample[key] for sample in batch] - if key == "tokens" or key == "labels": - # Keep as list of lists for variable length sequences - collated[key] = values - else: - # Use default collation for scalars - collated[key] = torch.utils.data.default_collate(values) - - # Add batch-level metrics key for downstream processing - if batch_metrics: - collated["metrics"] = batch_metrics + all_metrics.extend(sample.pop("metrics")) + clean_batch.append(sample) + + if not clean_batch: + return {"metrics": all_metrics} + + # Simple padding for tokens + ids = torch.tensor([item["id"] for item in clean_batch]) + tokens = torch.nn.utils.rnn.pad_sequence( + [torch.tensor(item["tokens"]) for item in clean_batch], + batch_first=True, + padding_value=-1, # Use -1 for padding to distinguish from valid IDs + ) + collated = { + "id": ids, + "tokens": tokens, + } + + # Add text field for non-tensor data + if "text" in clean_batch[0]: + collated["text"] = [item["text"] for item in clean_batch] + collated["metrics"] = all_metrics return collated def generate_ckpt( dataloader: DataLoader, + aggregator: MetricsAggregator, steps_before_checkpoint: int, steps_after_checkpoint: int, resume_dataloader: Optional[DataLoader] = None, + resume_aggregator: Optional[MetricsAggregator] = None, ) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. - Optionally, a second dataloader can be given to resume from checkpoint + Optionally, a second dataloader and aggregator can be given to resume from ckpt and run steps_after_checkpoint to match the first one. - Collects and aggregates metrics for test validation purposes. - Args: dataloader (DataLoader): The dataloader to test + aggregator (MetricsAggregator): The metrics aggregator to use steps_before_checkpoint (int): Number of steps to run before saving checkpoint steps_after_checkpoint (int): Number of steps to run after checkpoint resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming. 
If None, returns empty resumed_batches. + resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. + If None, returns empty resumed_metrics. Returns: - dict[str, Any]: Dict with batches and aggregated metrics for validation. + dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. """ iterator = iter(dataloader) - # Collect batches and metrics before and after checkpoint + # Collect batches before and after checkpoint batches = [] - all_metrics = [] # All metrics collected during the run - checkpoint_metrics = [] # Metrics collected only up to checkpoint checkpoint_state = None + metrics_at_checkpoint = {} total_steps = steps_before_checkpoint + steps_after_checkpoint for idx, batch in enumerate(iterator): batches.append(batch) - # Collect metrics for test validation + # Process metrics if "metrics" in batch: - batch_metrics = batch.pop("metrics") - all_metrics.extend(batch_metrics) - - # If we haven't reached checkpoint yet, also add to checkpoint metrics - if idx < steps_before_checkpoint: - checkpoint_metrics.extend(batch_metrics) + aggregator.update(batch.pop("metrics")) # Save checkpoint state after steps_before_checkpoint if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based checkpoint_state = { "loader": dataloader.state_dict(), + "aggregator": aggregator.state_dict(), } + metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") # Stop after total steps if idx == total_steps - 1: @@ -101,56 +102,43 @@ def generate_ckpt( pre_checkpoint_batches = batches[:steps_before_checkpoint] post_checkpoint_batches = batches[steps_before_checkpoint:] - # Compute metrics for post-checkpoint batches only - post_checkpoint_metrics = all_metrics[len(checkpoint_metrics) :] - - # Resume with new instance if provided + # Resume with new instances if provided resumed_batches = [] - resumed_metrics = [] - - if resume_dataloader is not None and checkpoint_state is not None: - # Test resuming with new instance + resumed_metrics = {} + + if ( + resume_dataloader is not None + and resume_aggregator is not None + and checkpoint_state is not None + ): + # Test resuming with new instances resume_dataloader.load_state_dict(checkpoint_state["loader"]) + resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) resume_iterator = iter(resume_dataloader) # Collect only the post-checkpoint batches when resuming for idx, batch in enumerate(resume_iterator): resumed_batches.append(batch) - # Collect metrics from resumed batches + # Process metrics if "metrics" in batch: - batch_metrics = batch.pop("metrics") - resumed_metrics.extend(batch_metrics) + resume_aggregator.update(batch.pop("metrics")) # Stop after steps_after_checkpoint if idx == steps_after_checkpoint - 1: break + resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") + return { # Original run "pre_checkpoint_batches": pre_checkpoint_batches, "post_checkpoint_batches": post_checkpoint_batches, - "metrics_at_checkpoint": aggregate_metrics(checkpoint_metrics), - "post_checkpoint_metrics": aggregate_metrics(post_checkpoint_metrics), - "final_metrics": aggregate_metrics(all_metrics), + "metrics_at_checkpoint": metrics_at_checkpoint, + "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), # Resumed run "resumed_batches": resumed_batches, - "resumed_metrics": aggregate_metrics(resumed_metrics), + "resumed_metrics": resumed_metrics, # Internal state for loading - only if someone needs to manually load "_checkpoint_state": 
checkpoint_state, } - - -def aggregate_metrics(metrics_list: list) -> dict[str, Any]: - if not metrics_list: - return {} - - accumulators = {} - - for metric in metrics_list: - key = metric.key - if key not in accumulators: - accumulators[key] = metric.reduction.accumulator_class(metric.reduction) - accumulators[key].append(metric.value) - - return {key: acc.get_value() for key, acc in accumulators.items()} diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 56cd5ff02..352fbf703 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ -14,6 +14,7 @@ import torch from forge.data.collate import collate_packed +from forge.data.dataset_metrics import MetricsAggregator from forge.data.datasets import HfIterableDataset from forge.data.datasets.packed import ( _SUPPORTS_FLEX_ATTENTION, @@ -913,7 +914,7 @@ def test_checkpoint_and_resume(self, dataset_factory): batch_size = 1 # Setup dataset factory - def create_loader(): + def create_loader_and_aggregator(): dataset = dataset_factory(samples) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = PackedDataset( @@ -930,10 +931,11 @@ def create_loader(): loader = StatefulDataLoader( packed_dataset, batch_size=batch_size, collate_fn=collate_fn ) - return loader + aggregator = MetricsAggregator() + return loader, aggregator - loader1 = create_loader() - loader2 = create_loader() + loader1, aggregator1 = create_loader_and_aggregator() + loader2, aggregator2 = create_loader_and_aggregator() steps_before_checkpoint = 2 steps_after_checkpoint = 2 @@ -941,9 +943,11 @@ def create_loader(): # Generate checkpoint and resume result = generate_ckpt( loader1, + aggregator1, steps_before_checkpoint=steps_before_checkpoint, steps_after_checkpoint=steps_after_checkpoint, resume_dataloader=loader2, + resume_aggregator=aggregator2, ) # Verify that checkpointing and resumption work