diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml
index 44e4485e4..e9ddc625a 100644
--- a/apps/sft/llama3_8b.yaml
+++ b/apps/sft/llama3_8b.yaml
@@ -46,7 +46,7 @@ parallelism:
 checkpoint:
   enable: true
   folder: ./checkpoint # The folder to save checkpoints to.
-  initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
+  initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
   initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
   last_save_in_hf: true
   interval: 500
@@ -56,6 +56,12 @@ activation_checkpoint:
   mode: selective
   selective_ac_option: op
 
+metric_logging:
+  wandb:
+    project: sft-training
+    group: sft_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+
 # profiling:
 # enable_profiling: false
 
diff --git a/apps/sft/main.py b/apps/sft/main.py
index aa484608e..edda0b49d 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -27,6 +27,7 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
 from monarch.actor import current_rank, current_size, endpoint
 
@@ -77,7 +78,6 @@ def __init__(self, config: DictConfig):
         self.current_step = 0
         self.num_training_steps = job_config.training.steps
 
-        self.metric_logger = None # TODO: fix this
         self.gradient_accumulation_steps = 1 # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
@@ -109,9 +109,22 @@ def _init_dist(self):
         os.environ.update(env)
         logger.info("env: {}".format(env))
 
+    async def setup_metric_logger(self):
+        """Initialization happens in the main process. Here we just retrieve it."""
+        mlogger = await get_or_create_metric_logger()
+        return mlogger
+
+    def record_batch_metrics(self, data_metrics: list):
+        """Since the dataloader creates new processes, we don't call `record_metric` in the dataset.
+ Instead, pop the metrics from the batch and record them here.""" + for metric in data_metrics: + record_metric(metric.key, metric.value, metric.reduction) + @endpoint async def setup(self): self.train_dataloader = self.setup_data() + self.mlogger = await self.setup_metric_logger() + # self.train_dataloader = self.setup_data( # self.train_config.train_dataset_config, # self.train_config.train_dataloader_config, @@ -234,7 +247,9 @@ def train_step(self, batch) -> None: # ) as grad_acc: labels = batch.pop("labels") loss = self.forward_backward(batch, labels) + loss = loss.item() + record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN) logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}") # self.pbar.set_description(f"{self.current_step}|Loss: {loss}") # self.pbar.update(1) @@ -251,14 +266,25 @@ async def train(self) -> None: while self.current_step < self.num_training_steps: batch = next(dataloader) + + # Pop and record metrics from batch before moving to device + self.record_batch_metrics(batch.pop("metrics", [])) + record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN) + # Move tensors to the appropriate device for k, v in batch.items(): if isinstance(v, torch.Tensor): batch[k] = v.to("cuda") # TODO: hardcoded for now + self.train_step(batch) # self.profiler.step() self.current_step += 1 + # Flush metrics + if self._rank == 0: + logger.debug(f"Flushing metrics at step {self.current_step}") + await self.mlogger.flush.call_one(global_step=self.current_step) + self.checkpointer.save( curr_step=self.current_step, last_step=self.current_step == self.num_training_steps, @@ -270,16 +296,23 @@ async def train(self) -> None: async def cleanup(self) -> None: if self.checkpointer: self.checkpointer.close() - if self.metric_logger: - self.metric_logger.close() + if getattr(self, "mlogger", None): + await self.mlogger.shutdown.call_one() def __repr__(self) -> str: return "Trainer" async def run(cfg: DictConfig) -> None: - logging.info("Spawing recipe...") + + logging.info("Spawning recipe...") process_cfg = cfg.pop("processes") + + # Initialize metric logger in main process + metric_logging_cfg = cfg.get("metric_logging", {}) + mlogger = await get_or_create_metric_logger(process_name="Controller") + await mlogger.init_backends.call_one(metric_logging_cfg) + recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg) logging.info("Created recipe, running setup.") @@ -290,6 +323,7 @@ async def run(cfg: DictConfig) -> None: logging.info("Done training. Clean up") await recipe.cleanup.call() + await recipe.mesh.stop() logging.info("All done!") diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml index 1c0d5bc8b..f7c4999bb 100644 --- a/apps/sft/qwen3_8b.yaml +++ b/apps/sft/qwen3_8b.yaml @@ -45,7 +45,7 @@ parallelism: checkpoint: enable: true folder: ./checkpoint # The folder to save checkpoints to. - initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists. + initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists. 
initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo last_save_in_hf: true interval: 500 @@ -55,6 +55,12 @@ activation_checkpoint: mode: selective selective_ac_option: op +metric_logging: + wandb: + project: sft-training + group: sft_exp_${oc.env:USER} + logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce + # profiling: # enable_profiling: false diff --git a/src/forge/data/__init__.py b/src/forge/data/__init__.py index 4347199b9..74ba663e0 100644 --- a/src/forge/data/__init__.py +++ b/src/forge/data/__init__.py @@ -4,6 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from .collate import collate_packed +from .metric_transform import DefaultDatasetMetricTransform, MetricTransform from .utils import CROSS_ENTROPY_IGNORE_IDX -__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"] +__all__ = [ + "collate_packed", + "CROSS_ENTROPY_IGNORE_IDX", + "MetricTransform", + "DefaultDatasetMetricTransform", +] diff --git a/src/forge/data/dataset_metrics/__init__.py b/src/forge/data/dataset_metrics/__init__.py deleted file mode 100644 index 3a218e282..000000000 --- a/src/forge/data/dataset_metrics/__init__.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from .metric_agg_handlers import ( - AggregationHandler, - CategoricalCountAggHandler, - MaxAggHandler, - MeanAggHandler, - MetricState, - MinAggHandler, - StatsAggHandler, - SumAggHandler, -) -from .metric_aggregator import MetricsAggregator -from .metric_transform import ( - AggregationType, - DefaultTrainingMetricTransform, - Metric, - MetricTransform, -) - -__all__ = [ - "AggregationType", - "AggregationHandler", - "CategoricalCountAggHandler", - "DefaultTrainingMetricTransform", - "StatsAggHandler", - "MaxAggHandler", - "MeanAggHandler", - "Metric", - "MetricState", - "MetricsAggregator", - "MetricTransform", - "MinAggHandler", - "SumAggHandler", -] diff --git a/src/forge/data/dataset_metrics/metric_agg_handlers.py b/src/forge/data/dataset_metrics/metric_agg_handlers.py deleted file mode 100644 index bb3978a6b..000000000 --- a/src/forge/data/dataset_metrics/metric_agg_handlers.py +++ /dev/null @@ -1,466 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import logging -from abc import ABC, abstractmethod -from collections import Counter, deque -from dataclasses import dataclass, field -from typing import Any - -import torch - -from .metric_transform import AggregationType, Metric - -logger = logging.getLogger(__name__) - - -@dataclass -class MetricState: - """Mutable state object representing the state of a (source, metric_name) on a single rank. - - Attributes: - source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation. - metric_name (str): Name of the metric. - value (float): Current aggregated value, whose meaning depends on the aggregation type - (e.g., running sum, current max). - agg_type (AggregationType): Aggregation type. - metadata (dict[str, Any]): Additional state like count, list of values, etc. 
- """ - - source: str - metric_name: str - value: float - agg_type: AggregationType - metadata: dict[str, Any] = field(default_factory=dict) - - -class AggregationHandler(ABC): - """Base class for handling metric aggregation. - - Each handler implements a specific aggregation strategy (SUM, MEAN, STATS, etc.) - and manages the complete lifecycle: initialization, updates, local finalization, - and distributed reduction. Handlers also handle serialization for checkpointing. - - The handler architecture allows pluggable aggregation strategies while maintaining - consistent interfaces for the MetricsAggregator. - """ - - @abstractmethod - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - """Create a new MetricState for a (source, metric_name) pair. - - Args: - source (str): Name of the source, e.g. the dataset name. Used for logging and disambiguation. - metric_name (str): Name of the metric. - agg_type (AggregationType): Aggregation type. - - Returns: - MetricState: New MetricState for this (source, metric_name) pair. - """ - pass - - @abstractmethod - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - """Update cumulative MetricState with new metric info. - - Args: - local_agg_metric (MetricState): State of the aggregation for this metric in the local rank. - metric (Metric): Input metric info. - """ - pass - - @abstractmethod - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - """ - Computes the final value from the locally aggregated state. For example, for mean - it would mean to divide the tracked sum by the tracked count. - - This method may expand a single metric into multiple, for instance, - a list of numbers into mean, min, max, and percentiles. - - Args: - local_agg_metric (MetricState): The locally aggregated metric state to finalize. - - Returns: - list[MetricState]: List of finalized metric states. - """ - pass - - @abstractmethod - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - """ - Merge MetricStates from all ranks into final result. For example, for 'sum', it would mean to - sum the values from all ranks. - - Args: - local_agg_metrics (list[MetricState]): list of MetricStates from all ranks for a specific - (source, metric_name) tuple after computing finalize_local_agg. - - Returns: - MetricState: Final result for this (source, metric_name) pair. - """ - pass - - def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert handler-specific metadata to serializable format. Override this when using - non-serializable types like deque or Counter. For example, convert deque to list, Counter to dict. - - Args: - metadata (dict[str, Any]): AggHandler-specific metadata. - - Returns: - dict[str, Any]: Serializable metadata. - """ - return metadata.copy() - - def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Restore handler-specific metadata from serialized format. Override this to reverse the - serialize_metadata transformation. For example, convert list back to deque, dict back to Counter. - - Args: - metadata (dict[str, Any]): AggHandler-specific metadata. - - Returns: - dict[str, Any]: Deserialized metadata. - """ - return metadata.copy() - - -class SumAggHandler(AggregationHandler): - """AggHandler for SUM aggregation. 
Initializes with 0.0 and accumulates metric values.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - if not isinstance(metric.value, (int, float)): - raise ValueError( - f"SumAggHandler expects numeric values, got {type(metric.value)}" - ) - local_agg_metric.value += metric.value - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - if not local_agg_metrics: - raise ValueError("Cannot aggregate empty list of metrics") - - total = sum(metric.value for metric in local_agg_metrics) - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=total, - agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy(), - ) - - -class MaxAggHandler(AggregationHandler): - """AggHandler for MAX aggregation. Tracks maximum value across all updates.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=float("-inf"), - agg_type=agg_type, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - if not isinstance(metric.value, (int, float)): - raise ValueError( - f"MaxAggHandler expects numeric values, got {type(metric.value)}" - ) - local_agg_metric.value = max(local_agg_metric.value, metric.value) - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - max_value = max(r.value for r in local_agg_metrics) - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=max_value, - agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy(), - ) - - -class MinAggHandler(AggregationHandler): - """AggHandler for MIN aggregation. Tracks minimum value across all updates.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=float("inf"), - agg_type=agg_type, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - if not isinstance(metric.value, (int, float)): - raise ValueError( - f"MinAggHandler expects numeric values, got {type(metric.value)}" - ) - local_agg_metric.value = min(local_agg_metric.value, metric.value) - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - min_value = min(r.value for r in local_agg_metrics) - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=min_value, - agg_type=local_agg_metrics[0].agg_type, - metadata=local_agg_metrics[0].metadata.copy(), - ) - - -class MeanAggHandler(AggregationHandler): - """AggHandler for MEAN aggregation. 
Maintains running sum and count to compute average.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - metadata={"sum": 0.0, "count": 0}, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - local_agg_metric.metadata["sum"] += metric.value - local_agg_metric.metadata["count"] += 1 - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - count = local_agg_metric.metadata["count"] - local_agg_metric.value = ( - local_agg_metric.metadata["sum"] / count if count > 0 else 0.0 - ) - return [local_agg_metric] - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - total_sum = sum(metric.metadata["sum"] for metric in local_agg_metrics) - total_count = sum(metric.metadata["count"] for metric in local_agg_metrics) - - return MetricState( - source=local_agg_metrics[0].source, - metric_name=local_agg_metrics[0].metric_name, - value=total_sum / total_count if total_count > 0 else 0.0, - agg_type=local_agg_metrics[0].agg_type, - metadata={"sum": total_sum, "count": total_count}, - ) - - -class StatsAggHandler(AggregationHandler): - """AggHandler for STATS aggregation. Maintains a sliding window of values - and expands into multiple statistical metrics (mean, min, max, percentiles, std). - - Note: Percentiles and standard deviation are approximated in distributed settings by averaging local - percentiles and standard deviations across ranks. This is mathematically imprecise but provides a - reasonable approximation for monitoring purposes. - - Args: - window_size (int): Maximum number of recent values to retain for statistics. - - Raises: - ValueError: If window_size is not positive. - """ - - def __init__(self, window_size: int = 1000): - if window_size <= 0: - raise ValueError(f"window_size must be positive, got {window_size}") - self.window_size = window_size - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - metadata={"values": deque(maxlen=self.window_size)}, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - local_agg_metric.metadata["values"].append(metric.value) - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - values = list(local_agg_metric.metadata["values"]) - if not values: - return [] - - values_tensor = torch.tensor(values, dtype=torch.float64) - n = len(values_tensor) - - # Compute stats from the tensor - sum_val = torch.sum(values_tensor).item() - mean_val = sum_val / n - min_val = torch.min(values_tensor).item() - max_val = torch.max(values_tensor).item() - - # Compute percentiles - percentile_definitions = torch.tensor([0.05, 0.5, 0.95], dtype=torch.float64) - p05_val, p50_val, p95_val = torch.quantile( - values_tensor, percentile_definitions - ).tolist() - - # Return multiple MetricStates with proper agg_types for distributed reduction - # NOTE: Percentiles use MEAN aggregation which approximates global percentiles - # by averaging local percentiles. 
- metrics = [ - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_mean", - value=mean_val, - agg_type=AggregationType.MEAN, - metadata={"sum": sum_val, "count": n}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_min", - value=min_val, - agg_type=AggregationType.MIN, - metadata={}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_max", - value=max_val, - agg_type=AggregationType.MAX, - metadata={}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_p05", - value=p05_val, - agg_type=AggregationType.MEAN, - metadata={"sum": p05_val, "count": 1}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_p50", - value=p50_val, - agg_type=AggregationType.MEAN, - metadata={"sum": p50_val, "count": 1}, - ), - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_p95", - value=p95_val, - agg_type=AggregationType.MEAN, - metadata={"sum": p95_val, "count": 1}, - ), - ] - - # Standard deviation is only well-defined for n > 1 - if n > 1: - std_val = torch.std(values_tensor).item() - metrics.append( - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_stat_std", - value=std_val, - agg_type=AggregationType.MEAN, - metadata={"sum": std_val, "count": 1}, - ) - ) - return metrics - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - raise NotImplementedError( - "Metrics with AggregationType.STATS were converted to other " - "AggregationTypes for distributed reduction. finalize_dist_agg should not be called." - ) - - def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert deque to list for serialization.""" - serialized = metadata.copy() - if "values" in serialized: - serialized["values"] = list(serialized["values"]) - return serialized - - def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert list back to deque.""" - deserialized = metadata.copy() - if "values" in deserialized: - deserialized["values"] = deque( - deserialized["values"], maxlen=self.window_size - ) - return deserialized - - -class CategoricalCountAggHandler(AggregationHandler): - """AggHandler for CATEGORICAL_COUNT aggregation. 
Counts occurrences of categorical values - and expands into individual count metrics for each category.""" - - def initialize_metric_state( - self, source: str, metric_name: str, agg_type: AggregationType - ) -> MetricState: - return MetricState( - source=source, - metric_name=metric_name, - value=0.0, - agg_type=agg_type, - metadata={"counts": Counter()}, - ) - - def update(self, local_agg_metric: MetricState, metric: Metric) -> None: - local_agg_metric.metadata["counts"][metric.value] += 1 - - def finalize_local_agg(self, local_agg_metric: MetricState) -> list[MetricState]: - # Expand categorical counts into individual metrics - results = [] - for category, count in local_agg_metric.metadata["counts"].items(): - results.append( - MetricState( - source=local_agg_metric.source, - metric_name=f"{local_agg_metric.metric_name}_count_{category}", - value=count, - agg_type=AggregationType.SUM, - ) - ) - return results - - def finalize_dist_agg(self, local_agg_metrics: list[MetricState]) -> MetricState: - raise NotImplementedError( - "Metrics with AggregationType.CATEGORICAL_COUNT were converted to other " - "AggregationType.SUM for distributed reduction. finalize_dist_agg should not be called." - ) - - def serialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert Counter to dict for serialization.""" - serialized = metadata.copy() - if "counts" in serialized: - serialized["counts"] = dict(serialized["counts"]) - return serialized - - def deserialize_metadata(self, metadata: dict[str, Any]) -> dict[str, Any]: - """Convert dict back to Counter.""" - deserialized = metadata.copy() - if "counts" in deserialized: - deserialized["counts"] = Counter(deserialized["counts"]) - return deserialized diff --git a/src/forge/data/dataset_metrics/metric_aggregator.py b/src/forge/data/dataset_metrics/metric_aggregator.py deleted file mode 100644 index 40d8075ce..000000000 --- a/src/forge/data/dataset_metrics/metric_aggregator.py +++ /dev/null @@ -1,344 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import ast -import logging -from collections import defaultdict -from typing import Any, Union - -import torch.distributed as dist - -from .metric_agg_handlers import ( - AggregationHandler, - CategoricalCountAggHandler, - MaxAggHandler, - MeanAggHandler, - MetricState, - MinAggHandler, - StatsAggHandler, - SumAggHandler, -) -from .metric_transform import AggregationType, Metric - -logger = logging.getLogger(__name__) - - -class MetricsAggregator: - """Aggregates metrics across datasets and distributed ranks using pluggable handlers. - - This class uses a handler-based strategy, where each aggregation type (SUM, MEAN, etc.) - has a corresponding AggregationHandler. It maintains a single state object for each - (source, metric_name) pair. - - Internal State Visualization: - { - ("alpaca", "tokens_seen"): MetricState(value=200.0, agg_type=SUM, ...), - ("alpaca", "avg_loss"): MetricState(value=0.01, agg_type=MEAN, metadata={'sum': ..., 'count': ...}), - ("slim_orca", "seq_len"): MetricState(agg_type=STATS, metadata={'values': deque([...])}), - } - - When preparing metrics for logging, the aggregator follows a two-phase process: - 1. Local Aggregation: Each rank aggregates its metrics independently - 2. 
Distributed Reduction: If in distributed mode, results are combined across ranks - - The aggregator's state is checkpointable, allowing training resumption. - - Args: - dist_window_size (int): Window size for StatsAggHandler tracking. - - Example: - >>> from forge.data.metrics import MetricsAggregator, Metric, AggregationType - >>> - >>> aggregator = MetricsAggregator() - >>> - >>> # Sample metrics from different batches - >>> batch1_metrics = [ - ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), - ... ] - >>> - >>> batch2_metrics = [ - ... Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - ... Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), - ... ] - >>> - >>> # Update with metrics - >>> aggregator.update(batch1_metrics) - >>> aggregator.update(batch2_metrics) - >>> - >>> # Get final results - >>> results = aggregator.get_metrics_for_logging(prefix="train") - >>> # {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} - - Raises: - ValueError: If dist_window_size is not positive. - """ - - def __init__(self, dist_window_size: int = 1000): - if dist_window_size <= 0: - raise ValueError( - f"dist_window_size must be positive, got {dist_window_size}" - ) - - # Storage: {(source, metric_name): MetricState} - O(unique metrics) not O(samples) - self._metric_states: dict[tuple[str, str], MetricState] = {} - self._dist_window_size = dist_window_size - - # Track aggregation types for validation - prevents same metric name with different agg types - self._metric_agg_types: dict[tuple[str, str], AggregationType] = {} - - # Create handler registry - all handlers initialized upfront - self._handlers: dict[AggregationType, AggregationHandler] = { - AggregationType.SUM: SumAggHandler(), - AggregationType.MAX: MaxAggHandler(), - AggregationType.MIN: MinAggHandler(), - AggregationType.MEAN: MeanAggHandler(), - AggregationType.STATS: StatsAggHandler(dist_window_size), - AggregationType.CATEGORICAL_COUNT: CategoricalCountAggHandler(), - } - - def _validate_metric_consistency(self, metric: Union[Metric, MetricState]) -> None: - """Validate that metric name uses consistent aggregation type.""" - metric_key = (metric.source, metric.metric_name) - metric_name = metric.metric_name - - if metric_key in self._metric_agg_types: - existing_agg_type = self._metric_agg_types[metric_key] - if existing_agg_type != metric.agg_type: - raise ValueError( - f"Metric '{metric_name}' in dataset '{metric.source}' " - f"is already registered with aggregation type {existing_agg_type.value}, " - f"but a handler or user code tried to use it with type {metric.agg_type.value}. " - f"Use different metric names for different aggregation types." - ) - else: - # Track this metric's aggregation type - self._metric_agg_types[metric_key] = metric.agg_type - - def register_handler( - self, agg_type: AggregationType, handler: AggregationHandler - ) -> None: - """Register custom aggregation handler for specified type. - - Args: - agg_type (AggregationType): The aggregation type to handle - handler (AggregationHandler): Handler instance implementing the AggregationHandler interface - """ - # Warn if replacing a handler that's already in use - if agg_type in self._handlers and any( - state.agg_type == agg_type for state in self._metric_states.values() - ): - logger.warning( - f"Replacing handler for {agg_type} - aggregation type already in use by existing metrics. " - f"This may affect existing metric behavior." 
- ) - - self._handlers[agg_type] = handler - - def update(self, metrics: list[Metric]) -> None: - """Update (source, metric_name) metric state with new values. - - Args: - metrics (list[Metric]): List of metrics to update the state with - - Raises: - ValueError: If no handler is registered for a metric's aggregation type, - or if metric name conflicts with existing aggregation type. - """ - for metric in metrics: - # Same metric name must use same aggregation type - self._validate_metric_consistency(metric) - - metric_key = (metric.source, metric.metric_name) - handler = self._handlers.get(metric.agg_type) - - if handler is None: - raise ValueError( - f"No handler registered for aggregation type: {metric.agg_type}" - ) - - if metric_key not in self._metric_states: - self._metric_states[metric_key] = handler.initialize_metric_state( - metric.source, metric.metric_name, metric.agg_type - ) - - local_agg_metric = self._metric_states[metric_key] - handler.update(local_agg_metric, metric) # Mutates local_agg_metric - - def get_metrics_for_logging(self, prefix: str = "data") -> dict[str, float]: - """Get final metrics for logging in standard format. - - Args: - prefix (str): Prefix for metric names in the returned dictionary - - Returns: - dict[str, float]: Dictionary with keys like "{prefix}_{source}/{metric_name}" - and float values. For example, with `prefix="train"`, `source="alpaca"`, - `metric_name="loss"`, the key would be `train_alpaca/loss`. - """ - final_results = self._compute_unified_metrics() - - return { - f"{prefix}_{result.source}/{result.metric_name}": result.value - for result in final_results - } - - def _compute_unified_metrics(self) -> list[MetricState]: - """ - Compute metrics handling both local and distributed cases uniformly. - - Returns: - list[MetricState]: Final results ready for logging - """ - # Step 1: Get local results from all handlers (may expand stats/categoricals) - prepared_results = [] - for local_agg_metric in self._metric_states.values(): - handler = self._handlers[local_agg_metric.agg_type] - generated_metrics = handler.finalize_local_agg(local_agg_metric) - - # Validate each newly generated metric state immediately - for gen_metric in generated_metrics: - self._validate_metric_consistency(gen_metric) - - prepared_results.extend(generated_metrics) - - # Step 2: Apply distributed reduction if needed - if dist.is_initialized() and dist.get_world_size() > 1: - prepared_results = self._finalize_dist_agg(prepared_results) - - return prepared_results - - def _finalize_dist_agg( - self, local_agg_metrics: list[MetricState] - ) -> list[MetricState]: - """Apply distributed reduction to local results. 
- - Args: - local_agg_metrics (list[MetricState]): (source, metric_name) metric pairs from this rank - - Returns: - list[MetricState]: Reduced results combining all ranks - """ - world_size = dist.get_world_size() - - # Gather all results from all ranks - all_results = [None] * world_size - dist.all_gather_object(all_results, local_agg_metrics) - - # Group by (source, metric_name) for reduction - grouped = defaultdict(list) - for rank_results in all_results: - if rank_results: # Handle ranks with no metrics - for result in rank_results: - result_key = (result.source, result.metric_name) - grouped[result_key].append(result) - - # Apply handler-specific distributed reduction - reduced_results = [] - for result_key, results_list in grouped.items(): - if not results_list: - continue # Skip empty groups - - # All results for a key should have same agg_type - agg_type = results_list[0].agg_type - handler = self._handlers[agg_type] - reduced_result = handler.finalize_dist_agg(results_list) - reduced_results.append(reduced_result) - - return reduced_results - - def state_dict(self) -> dict[str, Any]: - """Serialize aggregator state for checkpointing. - - Returns: - dict[str, Any]: Serializable dictionary containing all aggregator state - """ - serializable_state = {} - required_agg_types = set() # Track aggregation types used in saved states - - for metric_key, local_agg_metric in self._metric_states.items(): - # Get handler for this result's aggregation type - handler = self._handlers[local_agg_metric.agg_type] - required_agg_types.add(local_agg_metric.agg_type) - - # Convert MetricState to serializable dict - result_dict = { - "source": local_agg_metric.source, - "metric_name": local_agg_metric.metric_name, - "value": local_agg_metric.value, - "agg_type": local_agg_metric.agg_type, - "metadata": handler.serialize_metadata(local_agg_metric.metadata), - } - - # Convert tuple key to string for JSON compatibility - serializable_state[str(metric_key)] = result_dict - - return { - "state": serializable_state, - "dist_window_size": self._dist_window_size, - "required_agg_types": list( - required_agg_types - ), # Save which handlers are needed - # Save which aggregation types are used for each metric - "metric_agg_types": { - str(k): v.value for k, v in self._metric_agg_types.items() - }, - } - - def load_state_dict(self, state_dict: dict[str, Any]) -> None: - """Load aggregator state from checkpoint. - - Args: - state_dict (dict[str, Any]): Dictionary containing serialized aggregator state - - Raises: - ValueError: If required handlers are missing after checkpoint restore - """ - self._dist_window_size = state_dict.get("dist_window_size", 1000) - - # Sanity check: Ensure all required handlers are available - required_agg_types = state_dict.get("required_agg_types", []) - missing_handlers = [] - for agg_type in required_agg_types: - if agg_type not in self._handlers: - missing_handlers.append(agg_type) - - if missing_handlers: - raise ValueError( - f"Missing handlers for aggregation types: {missing_handlers}. " - f"Custom handlers must be re-registered before checkpoint restore." 
- ) - - deserialized_state = {} - for key_str, result_dict in state_dict["state"].items(): - # Convert string keys back to tuples - metric_key = ast.literal_eval(key_str) - - # Get handler for this aggregation type - agg_type = result_dict["agg_type"] - handler = self._handlers[agg_type] - - # Restore metadata using handler-specific deserialization - metadata = handler.deserialize_metadata(result_dict["metadata"]) - - # Create MetricState from dict - local_agg_metric = MetricState( - source=result_dict["source"], - metric_name=result_dict["metric_name"], - value=result_dict["value"], - agg_type=result_dict["agg_type"], - metadata=metadata, - ) - - deserialized_state[metric_key] = local_agg_metric - - self._metric_states = deserialized_state - - # Restore validation state - self._metric_agg_types = {} - for key_str, agg_type_str in state_dict.get("metric_agg_types", {}).items(): - key = ast.literal_eval(key_str) - self._metric_agg_types[key] = AggregationType(agg_type_str) diff --git a/src/forge/data/dataset_metrics/metric_transform.py b/src/forge/data/dataset_metrics/metric_transform.py deleted file mode 100644 index a9af39ade..000000000 --- a/src/forge/data/dataset_metrics/metric_transform.py +++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Any, Union - - -@dataclass(frozen=True) -class Metric: - source: str - metric_name: str - value: Union[int, float, str] - agg_type: "AggregationType" - - -class AggregationType(Enum): - """Defines how a metric's value should be aggregated by the MetricsAggregator. - - Each type corresponds to a specific AggregationHandler that implements the logic - for initialization, updates, and distributed reduction. - """ - - SUM = "sum" - MEAN = "mean" - STATS = "distribution" - CATEGORICAL_COUNT = "categorical_count" - MAX = "max" - MIN = "min" - - -class MetricTransform(ABC): - """Applied to each dataset sample to generate per-sample metrics for training tracking. - - Creates Metric objects that are later aggregated by MetricsAggregator. This separation - of concerns ensures metrics are correctly aggregated even with multiple dataloader - workers and in distributed settings. - - The transform must be configured with a source via set_source() before use. - Each call to __call__ adds metrics to the sample's "metrics" key. - - Example: - >>> transform = DefaultTrainingMetricTransform() - >>> transform.set_source("alpaca") - >>> sample = {"tokens": [1, 2, 3]} - >>> result = transform(sample) - >>> # result["metrics"] contains list of Metric objects - """ - - def set_source(self, source: str) -> None: - """Called by the dataset to set the namespace for metrics. - - This is used to differentiate metrics from multiple datasets, for example, - "alpaca/tokens_seen" vs. "slim_orca/tokens_seen". - - Args: - source (str): Name of the dataset, used for logging and disambiguation. - """ - self.source = source - - @abstractmethod - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: - """Generate metrics for a single sample. 
- - Args: - sample (dict[str, Any]): The sample dictionary to generate metrics from - - Returns: - list[Metric]: List of metrics generated for this sample - - Raises: - NotImplementedError: If subclass does not implement this method. - """ - raise NotImplementedError("Subclasses must implement _generate_metrics method") - - def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: - if not hasattr(self, "source"): - raise RuntimeError( - "'transform.set_source' must be called before using the transform." - ) - - # Generate metrics for this sample - metrics = self._generate_metrics(sample) - - # Add to existing metrics list or create new one - if "metrics" not in sample: - sample["metrics"] = [] - sample["metrics"].extend(metrics) - return sample - - -class DefaultTrainingMetricTransform(MetricTransform): - """Generates common training metrics: samples seen, tokens seen, and sequence length. - - This transform detects the token key in a sample, checking for "tokens" - first and then falling back to "input_ids". - - For details on the base class behavior, see MetricTransform. - - Tracked metrics: - - samples_seen: Cumulative count of samples processed (SUM aggregation) - - tokens_seen: Cumulative sum of all tokens processed (SUM aggregation) - - seq_len: Distribution stats of sequence lengths (STATS aggregation) - - Example: - >>> transform = DefaultTrainingMetricTransform() - >>> transform.set_source("alpaca") - >>> - >>> sample = {"tokens": [1, 2, 3, 4, 5]} # 5 tokens - >>> metrics = transform._generate_metrics(sample) - >>> # This generates the following Metric objects: - >>> # [ - >>> # Metric(source="alpaca", metric_name="samples_seen", value=1, agg_type=AggregationType.SUM), - >>> # Metric(source="alpaca", metric_name="tokens_seen", value=5, agg_type=AggregationType.SUM), - >>> # Metric(source="alpaca", metric_name="seq_len", value=5, agg_type=AggregationType.STATS) - >>> # ] - """ - - def _generate_metrics(self, sample: dict[str, Any]) -> list[Metric]: - # Determine token key - token_key = "tokens" if "tokens" in sample else "input_ids" - token_len = len(sample.get(token_key, [])) - - # Create metrics for this sample - return [ - Metric( - source=self.source, - metric_name="samples_seen", - value=1, - agg_type=AggregationType.SUM, - ), - Metric( - source=self.source, - metric_name="tokens_seen", - value=token_len, - agg_type=AggregationType.SUM, - ), - Metric( - source=self.source, - metric_name="seq_len", - value=token_len, - agg_type=AggregationType.STATS, - ), - ] diff --git a/src/forge/data/dataset_metrics/readme.md b/src/forge/data/dataset_metrics/readme.md deleted file mode 100644 index 76ec424e3..000000000 --- a/src/forge/data/dataset_metrics/readme.md +++ /dev/null @@ -1,176 +0,0 @@ -# forge Metrics Module - -## Overview - -The metrics module provides a robust system for tracking and aggregating training metrics across multiple datasets and distributed environments. It follows a **strategy pattern** design with pluggable aggregation handlers to efficiently handle different types of metrics. 
- -## Architecture Overview - -``` -┌────────────────────────────────────────────────────┐ -│ Training Loop │ -└─────────────────────┬──────────────────────────────┘ - │ -┌─────────────────────▼──────────────────────────────┐ -│ MetricTransform │ -│ • Applied to each sample │ -│ • Generates per-sample metrics │ -│ • Examples: tokens_seen, seq_len, samples_seen │ -└─────────────────────┬──────────────────────────────┘ - │ list[Metric] -┌─────────────────────▼──────────────────────────────┐ -│ MetricsAggregator │ -│ • Aggregates metrics across samples and ranks │ -│ • Uses pluggable AggregationHandlers │ -│ • Handles distributed reduction │ -└─────────────────────┬──────────────────────────────┘ - │ {prefix}_{source}/{metric_name} # prefix is "train", "val", etc. -┌─────────────────────▼──────────────────────────────┐ -│ Logging System │ -│ • W&B, TensorBoard, etc. │ -│ • Gets formatted metrics ready for logging │ -└────────────────────────────────────────────────────┘ -``` - -## File Structure - -- **`metric_transform.py`**: Defines `Metric`, `AggregationType`, and transform classes -- **`metric_agg_handlers.py`**: Aggregation strategy implementations -- **`metric_aggregator.py`**: Main aggregator orchestrating the handlers - -## Customizing metrics - -- **Custom transforms**: Extend `MetricTransform` for domain-specific metrics -- **Handler registration**: Register custom handlers for specialized aggregation needs - -####### -## TODO -## Move this from here to website docs -####### - -## Core Components - -### 1. MetricTransform -Generates per-sample metrics during data processing. - -**Key Features:** -- Applied to each sample in the dataset -- Creates `Metric` objects with dataset name, metric name, value, and aggregation type -- Handles dataset namespacing for multi-dataset scenarios - -**Example Usage:** -```python -from forge.data.metrics import DefaultTrainingMetricTransform, AggregationType - -transform = DefaultTrainingMetricTransform() -transform.set_source("alpaca") - -# Applied to each sample -sample = {"tokens": [1, 2, 3, 4, 5]} -sample = transform(sample) -# sample["metrics"] now contains: -# [ -# Metric(source="alpaca", name="samples_seen", value=1, agg_type=AggregationType.SUM), -# Metric(source="alpaca", name="tokens_seen", value=5, agg_type=AggregationType.SUM), -# Metric(source="alpaca", name="seq_len", value=5, agg_type=AggregationType.STATS) -# ] -``` - -### 2. MetricsAggregator -Efficiently aggregates metrics across samples and distributed ranks. - -**Key Features:** -- Handler-based strategy pattern for different aggregation types -- Distributed-aware with automatic rank reduction -- Checkpointable state for training resumption -- Keep track of (metric, dataset) pairs - -**Aggregation Types (at the time of writing):** -- `SUM`: Cumulative totals (e.g., total tokens processed) -- `MEAN`: Running averages (e.g., average loss) -- `MAX/MIN`: Extrema tracking (e.g., max sequence length seen) -- `STATS`: Statistical summaries (mean, min, max, percentiles) -- `CATEGORICAL_COUNT`: Category cumulative counts (e.g. 
num of samples from a given category) - -**Example Usage:** -```python -from forge.data.metrics import MetricsAggregator, Metric, AggregationType - -# Create aggregator -aggregator = MetricsAggregator() - -# Sample metrics from different batches -batch1_metrics = [ - Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), -] - -batch2_metrics = [ - Metric("alpaca", "tokens_seen", 100, AggregationType.SUM), - Metric("alpaca", "avg_tokens_seen", 100, AggregationType.MEAN), -] - -# Update with metrics -aggregator.update(batch1_metrics) -aggregator.update(batch2_metrics) - -# Get final results -results = aggregator.get_metrics_for_logging(prefix="train") -# {"train_alpaca/tokens_seen": 200.0, "train_alpaca/avg_tokens_seen": 100.0} -``` - -### 3. AggregationHandlers -Pluggable strategies for different aggregation patterns. - -``` -AggregationHandler (ABC) -├── SumAggHandler # value += metric.value -├── MeanAggHandler # tracks sum and count -├── MaxAggHandler # value = max(value, metric.value) -├── MinAggHandler # value = min(value, metric.value) -├── StatsAggHandler # maintains value window + stats -└── CategoricalCountAggHandler # Counter for categories -``` - -**Custom Handler Example:** -```python -class CustomAggHandler(AggregationHandler): - def initialize_metric_state(self, source, metric_name, agg_type): - return MetricState( - source=source, - metric_name=metric_name, - value=, # should change - agg_type=agg_type, - metadata={} # may need to change - ) - - def update(self, local_agg_metric, metric): - ... - - def finalize_local_agg(self, local_agg_metric): - ... - - def finalize_dist_agg(self, local_agg_metrics): - ... - -# Register with aggregator -aggregator.register_handler(AggregationType.CUSTOM, CustomAggHandler()) -``` - -## Distributed Training Support - -The metrics system automatically handles distributed environments: - -1. **Local Aggregation**: Each rank aggregates its own metrics -2. **Distributed Reduction**: Results are combined across ranks using `all_gather_object` -3. **Type-Aware Reduction**: Each aggregation type uses appropriate reduction (sum, mean, max, etc.) 
- -**Distributed Flow:** -``` -Rank 0: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] -Rank 1: [(ds1, metric1), (ds1, metric2)] → LocalAgg → [(ds1, metric1), (ds1, metric2)] - ↓ - AllGather + Reduce - ↓ - Final Results [(ds1, metric1), (ds1, metric2)] -``` diff --git a/src/forge/data/datasets/hf_dataset.py b/src/forge/data/datasets/hf_dataset.py index 799dd89b9..d7b36fe68 100644 --- a/src/forge/data/datasets/hf_dataset.py +++ b/src/forge/data/datasets/hf_dataset.py @@ -12,12 +12,8 @@ from datasets import load_dataset from datasets.distributed import split_dataset_by_node -from forge.data.dataset_metrics import ( - AggregationType, - DefaultTrainingMetricTransform, - Metric, - MetricTransform, -) +from forge.data.metric_transform import DefaultDatasetMetricTransform, MetricTransform +from forge.observability.metrics import Metric, Reduce from .dataset import DatasetInfo, InfiniteTuneIterableDataset @@ -85,7 +81,7 @@ def __init__( self._weight = weight if weight is not None else 1.0 # Create default transform if not provided - self._metric_transform = metric_transform or DefaultTrainingMetricTransform() + self._metric_transform = metric_transform or DefaultDatasetMetricTransform() # Auto-generate dataset name if not provided if dataset_name is None: @@ -239,15 +235,16 @@ def __iter__(self) -> Iterator[dict[str, Any]]: # Track the number of epochs completed for each dataset. This is # especially useful when interleaving multiple datasets, but # also necessary to track dataset-level metrics. - metric_num_epochs = Metric( - source=self.info.name, - metric_name="num_epochs", - value=self._num_epochs, - agg_type=AggregationType.MAX, - ) if "metrics" not in sample: sample["metrics"] = [] - sample["metrics"].append(metric_num_epochs) + + sample["metrics"].append( + Metric( + key=f"dataset/{self.info.name}/num_epochs", + value=self._num_epochs, + reduction=Reduce.MAX, + ) + ) samples_yielded += 1 yield sample diff --git a/src/forge/data/datasets/packed.py b/src/forge/data/datasets/packed.py index d09c158c0..93a21b85e 100644 --- a/src/forge/data/datasets/packed.py +++ b/src/forge/data/datasets/packed.py @@ -16,7 +16,7 @@ from torchdata.stateful_dataloader import Stateful from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.data.dataset_metrics import AggregationType, Metric +from forge.observability.metrics import Metric, Reduce from .dataset import DatasetInfo, InfiniteTuneIterableDataset @@ -605,13 +605,13 @@ def finalize_pack( # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) - padding_metric = Metric( - source=self.dataset_name, - metric_name="pct_of_tokens_padded", - value=padding_pct, - agg_type=AggregationType.MEAN, + pack["metrics"].append( + Metric( + key=f"dataset/{self.dataset_name}/pct_of_tokens_padded", + value=padding_pct, + reduction=Reduce.MEAN, + ) ) - pack["metrics"].append(padding_metric) # Concatenate tensor lists and handle other keys result = { @@ -635,7 +635,7 @@ def finalize_pack( if pack["input_pos"] else torch.empty(0, dtype=torch.long) ), - # "metrics": pack["metrics"], + "metrics": pack["metrics"], } # Handle arbitrary keys that aren't tensors - keep as lists @@ -853,13 +853,13 @@ def finalize_pack( # Add padding percentage metric if target_tokens_per_pack > 0: padding_pct = round(num_padding * 100 / target_tokens_per_pack, 2) - padding_metric = Metric( - source=self.dataset_name, - metric_name="pct_of_tokens_padded", - value=padding_pct, - 
agg_type=AggregationType.MEAN, + pack["metrics"].append( + Metric( + key=f"dataset/{self.dataset_name}/pct_of_tokens_padded", + value=padding_pct, + reduction=Reduce.MEAN, + ) ) - pack["metrics"].append(padding_metric) # Concatenate tensor lists and handle other keys result = { diff --git a/src/forge/data/datasets/sft_dataset.py b/src/forge/data/datasets/sft_dataset.py index 3a2574643..00278c1e5 100644 --- a/src/forge/data/datasets/sft_dataset.py +++ b/src/forge/data/datasets/sft_dataset.py @@ -9,7 +9,7 @@ import torch from forge.data import CROSS_ENTROPY_IGNORE_IDX -from forge.data.dataset_metrics import DefaultTrainingMetricTransform +from forge.data.metric_transform import DefaultDatasetMetricTransform from forge.data.utils import mask_messages, TuneMessage from .hf_dataset import HfIterableDataset @@ -198,7 +198,7 @@ def sft_iterable_dataset( message_transform=message_transform, model_transform=model_transform, output_transform=output_transform, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), shuffle_buffer_size=shuffle_buffer_size, weight=weight, seed=seed, diff --git a/src/forge/data/metric_transform.py b/src/forge/data/metric_transform.py new file mode 100644 index 000000000..cbbc04020 --- /dev/null +++ b/src/forge/data/metric_transform.py @@ -0,0 +1,113 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +from forge.observability.metrics import Metric, Reduce + + +class MetricTransform: + """ + Base class for transforms that collect observability metrics from dataset samples. + + This class provides a foundation for implementing dataset-level metric collection + during data processing pipelines. Subclasses should override the __call__ method + to add specific metrics to each sample that passes through the transform. + + Metrics are collected as `forge.observability.metrics.Metric` objects and made available + in batch["metrics"]. + + Attributes: + source (str, optional): The source name for metrics, typically the dataset name. + This is used as a prefix in metric keys to distinguish metrics from different + data sources. + + Example: + >>> transform = SomeMetricTransform() + >>> transform.set_source("training_data") + >>> processed_sample = transform(sample) + >>> # Metrics are automatically added to sample["metrics"] + """ + + def __init__(self): + self.source = None + + def set_source(self, source: str): + """Set the source name for metrics (typically the dataset name).""" + self.source = source + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + """Transform a sample by adding metrics to it.""" + return sample + + +class DefaultDatasetMetricTransform(MetricTransform): + """ + Collects basic dataset processing metrics during data pipeline execution. + + Metrics collected: + - samples_processed: Total number of samples that have passed through this transform (SUM) + - tokens_processed: Total number of tokens processed across all samples (SUM) + - mean_seq_len: Average sequence length across samples (MEAN) + - max_seq_len: Maximum sequence length observed (MAX) + - min_seq_len: Minimum sequence length observed (MIN) + + Note: Token-related metrics are only collected if the sample contains a 'tokens' field. + Sequence length is measured as the number of tokens in each sample. 
+ + Example: + >>> collector = DefaultDatasetMetricTransform() + >>> collector.set_source("training_data") + >>> sample = {"tokens": ["hello", "world"]} + >>> processed_sample = collector(sample) + >>> # Metrics are automatically added to processed_sample["metrics"] + """ + + def __call__(self, sample: dict[str, Any]) -> dict[str, Any]: + if "metrics" not in sample: + sample["metrics"] = [] + + source_name = self.source or "unnamed_ds" + + # Add samples_processed metric + sample["metrics"].append( + Metric( + key=f"dataset/{source_name}/samples_processed", + value=1, + reduction=Reduce.SUM, + ) + ) + + # Add token-based metrics if tokens are present + if "tokens" in sample: + token_count = len(sample.get("tokens", [])) + + sample["metrics"].extend( + [ + Metric( + key=f"dataset/{source_name}/tokens_processed", + value=token_count, + reduction=Reduce.SUM, + ), + Metric( + key=f"dataset/{source_name}/mean_seq_len", + value=token_count, + reduction=Reduce.MEAN, + ), + Metric( + key=f"dataset/{source_name}/max_seq_len", + value=token_count, + reduction=Reduce.MAX, + ), + Metric( + key=f"dataset/{source_name}/min_seq_len", + value=token_count, + reduction=Reduce.MIN, + ), + ] + ) + + return sample diff --git a/src/forge/observability/metric_actors.py b/src/forge/observability/metric_actors.py index 0ba2b8c73..d01617ba3 100644 --- a/src/forge/observability/metric_actors.py +++ b/src/forge/observability/metric_actors.py @@ -82,6 +82,7 @@ async def get_or_create_metric_logger( # Shutdown await mlogger.shutdown.call_one() """ + # Get or create the singleton global logger global _global_logger diff --git a/src/forge/observability/metrics.py b/src/forge/observability/metrics.py index c5ca7cf1f..35e12ab3f 100644 --- a/src/forge/observability/metrics.py +++ b/src/forge/observability/metrics.py @@ -544,7 +544,7 @@ def push(self, metric: Metric) -> None: " Metric logging backends (e.g. wandb) were not initialized." " This happens when you try to use `record_metric` before calling `init_backends`." " To disable this warning, please call in your main file:\n" - "`mlogger = await get_or_create_metric_logger()`\n" + "`mlogger = await get_or_create_metric_logger(process_name='Controller')`\n" "`await mlogger.init_backends.call_one(logging_config)`\n" "or set env variable `FORGE_DISABLE_METRICS=True`" ), diff --git a/tests/unit_tests/data/__init__.py b/tests/unit_tests/data/__init__.py deleted file mode 100644 index 2e41cd717..000000000 --- a/tests/unit_tests/data/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. diff --git a/tests/unit_tests/data/test_metrics_aggregator.py b/tests/unit_tests/data/test_metrics_aggregator.py deleted file mode 100644 index 5b847c92f..000000000 --- a/tests/unit_tests/data/test_metrics_aggregator.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Tests for MetricsAggregator functionality. 
- -This module tests the metrics collection and aggregation system including: -- All aggregation types (SUM, MEAN, MAX, MIN, STATS, CATEGORICAL_COUNT) -- State management and checkpointing -- Multi-dataset metric namespacing -- Distributed metrics aggregation -- Metric consistency validation - -Uses synthetic metrics to verify correct aggregation behavior across scenarios. -""" - -import logging - -import pytest -import torch.distributed as dist - -from forge.data.dataset_metrics import AggregationType, Metric, MetricsAggregator -from torch.testing._internal.common_fsdp import FSDPTest - -from tests.test_utils import gpu_test - - -class TestMetricsAggregator: - """Tests for MetricsAggregator core functionality and edge cases.""" - - @pytest.mark.parametrize( - "agg_type,test_values,expected", - [ - (AggregationType.SUM, [1, 2, 3, 4], 10), - (AggregationType.MEAN, [10, 20, 30, 40], 25.0), - (AggregationType.MAX, [-5, 10, 3, 15], 15), - (AggregationType.MIN, [5, -2, 8, 1], -2), - ( - AggregationType.CATEGORICAL_COUNT, - ["A", "B", "A", "C", "A"], - {"A": 3, "B": 1, "C": 1}, - ), - ], - ) - def test_aggregation_types(self, agg_type, test_values, expected): - """Tests each AggregationType with representative data to verify correct computation. - - Covers aggregation types: - - SUM: Simple addition across values - - MEAN: Average computation with proper count tracking - - MAX/MIN: Extrema identification - - CATEGORICAL_COUNT: Category frequency counting - """ - aggregator = MetricsAggregator() - - metrics = [ - Metric(source="test", metric_name="metric", value=val, agg_type=agg_type) - for val in test_values - ] - aggregator.update(metrics) - - result = aggregator.get_metrics_for_logging(prefix="train") - - if agg_type == AggregationType.CATEGORICAL_COUNT: - for category, count in expected.items(): - assert result[f"train_test/metric_count_{category}"] == count - else: - assert result["train_test/metric"] == expected - - def test_stats_metrics(self): - """Tests that STATS aggregation computes statistics (mean, min, max, percentiles).""" - aggregator = MetricsAggregator() - values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - - metrics = [ - Metric("test", "dist_metric", val, AggregationType.STATS) for val in values - ] - aggregator.update(metrics) - - result = aggregator.get_metrics_for_logging(prefix="train") - - assert result["train_test/dist_metric_stat_mean"] == 5.5 - assert result["train_test/dist_metric_stat_min"] == 1 - assert result["train_test/dist_metric_stat_max"] == 10 - assert result["train_test/dist_metric_stat_p50"] == 5.5 - - def test_state_management(self): - """Test metrics aggregator state persistence and restoration for checkpointing scenarios.""" - # Create aggregator with mixed metric types to test state saving - aggregator1 = MetricsAggregator() - initial_metrics = [ - Metric("ds1", "counter", 10, AggregationType.SUM), - Metric("ds1", "average", 5.0, AggregationType.MEAN), - Metric("ds2", "categories", "X", AggregationType.CATEGORICAL_COUNT), - ] - aggregator1.update(initial_metrics) - - # Save state - state = aggregator1.state_dict() - - # Create new aggregator and restore state - aggregator2 = MetricsAggregator() - aggregator2.load_state_dict(state) - - # Both should have identical metrics - metrics1 = aggregator1.get_metrics_for_logging(prefix="train") - metrics2 = aggregator2.get_metrics_for_logging(prefix="train") - assert metrics1 == metrics2 - - # Continue updating both - should remain identical - additional_metrics = [ - Metric("ds1", "counter", 5, AggregationType.SUM), - 
Metric("ds1", "average", 15.0, AggregationType.MEAN), - ] - aggregator1.update(additional_metrics) - aggregator2.update(additional_metrics) - - final_metrics1 = aggregator1.get_metrics_for_logging(prefix="train") - final_metrics2 = aggregator2.get_metrics_for_logging(prefix="train") - assert final_metrics1 == final_metrics2 - - # Verify expected values - assert final_metrics1["train_ds1/counter"] == 15 # 10 + 5 - assert final_metrics1["train_ds1/average"] == 10.0 # (5 + 15) / 2 - - def test_multiple_datasets(self): - """Test that metrics from multiple datasets are correctly namespaced.""" - aggregator = MetricsAggregator() - - metrics = [ - Metric("dataset1", "samples", 100, AggregationType.SUM), - Metric("dataset2", "samples", 200, AggregationType.SUM), - Metric("dataset1", "tokens", 1000, AggregationType.SUM), - Metric("dataset2", "tokens", 2000, AggregationType.SUM), - ] - aggregator.update(metrics) - - result = aggregator.get_metrics_for_logging(prefix="train") - - assert result["train_dataset1/samples"] == 100 - assert result["train_dataset2/samples"] == 200 - assert result["train_dataset1/tokens"] == 1000 - assert result["train_dataset2/tokens"] == 2000 - - def test_empty_aggregator(self): - """Test that empty aggregator returns empty metrics.""" - aggregator = MetricsAggregator() - result = aggregator.get_metrics_for_logging(prefix="train") - assert result == {} - - def test_prefix_handling(self): - """Test that prefix is correctly applied to metric keys.""" - aggregator = MetricsAggregator() - metrics = [ - Metric("test_ds", "metric1", 42, AggregationType.SUM), - Metric("test_ds", "metric2", 84, AggregationType.SUM), - ] - aggregator.update(metrics) - - # Test with prefix - result_with_prefix = aggregator.get_metrics_for_logging(prefix="validation") - assert result_with_prefix["validation_test_ds/metric1"] == 42 - assert result_with_prefix["validation_test_ds/metric2"] == 84 - - # Test without prefix (uses default "data") - result_no_prefix = aggregator.get_metrics_for_logging() - assert result_no_prefix["data_test_ds/metric1"] == 42 - assert result_no_prefix["data_test_ds/metric2"] == 84 - - def test_metric_consistency_validation(self): - """Test that same metric name must use same aggregation type.""" - aggregator = MetricsAggregator() - - # First metric with SUM aggregation - metrics1 = [Metric("test", "my_metric", 10, AggregationType.SUM)] - aggregator.update(metrics1) - - # Try to use same metric name with different aggregation type - should fail - metrics2 = [Metric("test", "my_metric", 5.0, AggregationType.MEAN)] - with pytest.raises( - ValueError, match="is already registered with aggregation type sum" - ): - aggregator.update(metrics2) - - # Same metric name with same aggregation type should work - metrics3 = [Metric("test", "my_metric", 20, AggregationType.SUM)] - aggregator.update(metrics3) # Should not raise - - result = aggregator.get_metrics_for_logging(prefix="train") - assert result["train_test/my_metric"] == 30 # 10 + 20 - - def test_metric_consistency_across_datasets(self): - """Test that same metric name can use different aggregation types across different datasets.""" - aggregator = MetricsAggregator() - - # Same metric name but different datasets - should be allowed - metrics = [ - Metric("dataset1", "metric", 10, AggregationType.SUM), - Metric("dataset2", "metric", 5.0, AggregationType.MEAN), - ] - aggregator.update(metrics) # Should not raise - - result = aggregator.get_metrics_for_logging(prefix="train") - assert result["train_dataset1/metric"] == 10 - 
assert result["train_dataset2/metric"] == 5.0 - - def test_handler_generated_metric_validation(self): - """Test that handler-generated metrics are validated for consistency.""" - aggregator = MetricsAggregator() - - # Create a user-defined metric that will conflict with stats - user_metrics = [ - Metric("test", "dist_metric_stat_mean", 42, AggregationType.SUM) - ] - aggregator.update(user_metrics) - - # Now try to add a stats metric that will generate conflicting stat names - dist_metrics = [Metric("test", "dist_metric", 10, AggregationType.STATS)] - aggregator.update(dist_metrics) - - # This should fail when trying to get metrics for logging because the handler - # will try to create "dist_metric_stat_mean" which conflicts with the user metric - with pytest.raises( - ValueError, match="is already registered with aggregation type sum" - ): - aggregator.get_metrics_for_logging(prefix="train") - - def test_handler_replacement_warning(self, caplog): - """Test that replacing handlers in use generates a warning.""" - aggregator = MetricsAggregator() - - # Add a metric that uses SUM aggregation - metrics = [Metric("test", "sum_metric", 10, AggregationType.SUM)] - aggregator.update(metrics) - - # Replace the SUM handler - should generate warning - from forge.data.dataset_metrics import SumAggHandler - - with caplog.at_level(logging.WARNING): - aggregator.register_handler(AggregationType.SUM, SumAggHandler()) - - # Check that the expected warning was logged - assert len(caplog.records) == 1 - assert "Replacing handler for AggregationType.SUM" in caplog.records[0].message - - -class TestDistributedMetricsAggregator(FSDPTest): - """Distributed tests for MetricsAggregator using FSDPTest infrastructure.""" - - @property - def world_size(self) -> int: - return 2 - - @gpu_test(gpu_count=2) - def test_distributed_all_aggregation_types(self): - """ - Test that all aggregation types work correctly in distributed setting. - Each rank contributes different values to ensure proper reduction across ranks. 
- """ - aggregator = MetricsAggregator() - rank = dist.get_rank() - - # Each rank contributes different values to test cross-rank aggregation - base_value = (rank + 1) * 10 # rank 0: 10, rank 1: 20 - - metrics = [ - Metric("test", "sum_metric", base_value, AggregationType.SUM), - Metric("test", "mean_metric", base_value + 5, AggregationType.MEAN), - Metric("test", "max_metric", base_value * 10, AggregationType.MAX), - Metric("test", "min_metric", base_value // 2, AggregationType.MIN), - ] - - # STATS: Each rank adds 5 values for statistics - # rank 0: [0, 1, 2, 3, 4], rank 1: [10, 11, 12, 13, 14] - for i in range(5): - metrics.append( - Metric("test", "dist_metric", rank * 10 + i, AggregationType.STATS) - ) - - # CATEGORICAL_COUNT: Different categories per rank to test counting - # rank 0: 3 of cat_A, 2 of cat_B - # rank 1: 1 of cat_A, 4 of cat_C - if rank == 0: - metrics.extend( - [ - Metric( - "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_B", AggregationType.CATEGORICAL_COUNT - ), - ] - ) - else: - metrics.extend( - [ - Metric( - "test", "cat_metric", "cat_A", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT - ), - Metric( - "test", "cat_metric", "cat_C", AggregationType.CATEGORICAL_COUNT - ), - ] - ) - - # Update aggregator and get results - aggregator.update(metrics) - result = aggregator.get_metrics_for_logging(prefix="train") - - # Verify aggregation results across all ranks - # SUM: rank 0 adds 10, rank 1 adds 20 -> total 30 - # MEAN: rank 0 has 15, rank 1 has 25 -> avg 20 - # MAX: rank 0 has 100, rank 1 has 200 -> max 200 - # MIN: rank 0 has 5, rank 1 has 10 -> min 5 - assert result["train_test/sum_metric"] == 30 - assert result["train_test/mean_metric"] == 20 - assert result["train_test/max_metric"] == 200 - assert result["train_test/min_metric"] == 5 - - # STATS: Combined values [0,1,2,3,4,10,11,12,13,14] - # Mean should be average of local means: (2 + 12) / 2 = 7 - assert result["train_test/dist_metric_stat_mean"] == 7 - assert result["train_test/dist_metric_stat_min"] == 0 - assert result["train_test/dist_metric_stat_max"] == 14 - - # CATEGORICAL_COUNT: Total counts across ranks - # cat_A: 3(rank0) + 1(rank1) = 4, cat_B: 2(rank0) + 0(rank1) = 2, cat_C: 0(rank0) + 4(rank1) = 4 - assert result["train_test/cat_metric_count_cat_A"] == 4 - assert result["train_test/cat_metric_count_cat_B"] == 2 - assert result["train_test/cat_metric_count_cat_C"] == 4 - - @gpu_test(gpu_count=2) - def test_distributed_state_dict_resumption(self): - """ - Test that MetricsAggregator state_dict save/restore works correctly in distributed setting. 
- Verifies: - - State can be saved after partial updates across ranks - - State can be restored consistently across ranks - - Continued updates after restore produce identical results - - Distributed aggregation works correctly after restoration - """ - rank = dist.get_rank() - - # Phase 1: Create aggregator and add initial metrics - aggregator1 = MetricsAggregator() - - # Each rank contributes different initial values - base_value = rank * 100 # rank 0: 0, rank 1: 100 - - initial_metrics = [ - Metric("test", "sum_metric", base_value, AggregationType.SUM), - Metric("test", "mean_metric", base_value // 2, AggregationType.MEAN), - Metric("test", "max_metric", base_value * 2, AggregationType.MAX), - ] - - # Add some STATS values - each rank adds 3 values - for i in range(3): - initial_metrics.append( - Metric("test", "dist_metric", rank * 100 + i, AggregationType.STATS) - ) - - # Add CATEGORICAL_COUNT values - if rank == 0: - initial_metrics.extend( - [ - Metric( - "test", - "cat_metric", - "type_A", - AggregationType.CATEGORICAL_COUNT, - ), - Metric( - "test", - "cat_metric", - "type_A", - AggregationType.CATEGORICAL_COUNT, - ), - ] - ) - else: - initial_metrics.extend( - [ - Metric( - "test", - "cat_metric", - "type_B", - AggregationType.CATEGORICAL_COUNT, - ), - Metric( - "test", - "cat_metric", - "type_B", - AggregationType.CATEGORICAL_COUNT, - ), - Metric( - "test", - "cat_metric", - "type_B", - AggregationType.CATEGORICAL_COUNT, - ), - ] - ) - - aggregator1.update(initial_metrics) - - # Save state_dict after initial update - state_dict = aggregator1.state_dict() - - # Phase 2: Create new aggregator and restore from state_dict - aggregator2 = MetricsAggregator() - aggregator2.load_state_dict(state_dict) - - # Verify both aggregators produce identical results after restore - result1 = aggregator1.get_metrics_for_logging(prefix="checkpoint") - result2 = aggregator2.get_metrics_for_logging(prefix="checkpoint") - assert ( - result1 == result2 - ), f"Rank {rank}: Aggregators differ after state_dict restore" - - # Phase 3: Add more metrics to both aggregators - additional_metrics = [ - Metric("test", "sum_metric", rank * 1000, AggregationType.SUM), - Metric("test", "min_metric", rank * 1000, AggregationType.MIN), - ] - - # Update both aggregators with additional metrics - aggregator1.update(additional_metrics) - aggregator2.update(additional_metrics) - - # Phase 4: Verify final results are identical across both aggregators - final_result1 = aggregator1.get_metrics_for_logging(prefix="final") - final_result2 = aggregator2.get_metrics_for_logging(prefix="final") - assert ( - final_result1 == final_result2 - ), f"Rank {rank}: Final results differ after continued updates" diff --git a/tests/unit_tests/data/test_metrics_transform.py b/tests/unit_tests/data/test_metrics_transform.py deleted file mode 100644 index 078b511a8..000000000 --- a/tests/unit_tests/data/test_metrics_transform.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
- -""" -Tests cover: -- DefaultTrainingMetricTransform -- Basic metric generation (samples_seen, tokens_seen, seq_len) -- Dataset name validation and requirements -- Proper metric type assignment and aggregation configuration -""" - -import pytest - -from forge.data.dataset_metrics import AggregationType, DefaultTrainingMetricTransform - - -class TestDefaultTrainingMetricTransform: - """Tests for DefaultTrainingMetricTransform functionality.""" - - def test_source_not_set_raises_error(self): - """Test that the transform raises a RuntimeError if used before - `set_source` is called, ensuring that metrics are always - correctly attributed to a dataset.""" - transform = DefaultTrainingMetricTransform() - sample = {"tokens": [1, 2, 3]} - - with pytest.raises(RuntimeError, match="set_source"): - transform(sample) - - def test_basic_metrics_generation(self): - """Test that transform generates expected training metrics for input samples.""" - transform = DefaultTrainingMetricTransform() - # Set dataset name required for metric generation - transform.set_source("test_dataset") - - sample = {"tokens": [1, 2, 3, 4, 5]} - result = transform(sample) - - # Transform should preserve original sample data unchanged - assert result["tokens"] == [1, 2, 3, 4, 5] - - # Should generate exactly 3 metrics: samples_seen, tokens_seen, seq_len - assert "metrics" in result - metrics = result["metrics"] - assert len(metrics) == 3 - - # Verify each metric has correct properties and aggregation type - for metric in metrics: - if metric.metric_name == "samples_seen": - assert metric.source == "test_dataset" - assert metric.value == 1 - assert metric.agg_type == AggregationType.SUM - - elif metric.metric_name == "tokens_seen": - assert metric.source == "test_dataset" - assert metric.value == 5 - assert metric.agg_type == AggregationType.SUM - - elif metric.metric_name == "seq_len": - assert metric.source == "test_dataset" - assert metric.value == 5 - assert metric.agg_type == AggregationType.STATS diff --git a/tests/unit_tests/datasets/test_hf.py b/tests/unit_tests/datasets/test_hf.py index c1535c8b8..8298bf1a8 100644 --- a/tests/unit_tests/datasets/test_hf.py +++ b/tests/unit_tests/datasets/test_hf.py @@ -26,8 +26,8 @@ import pytest import torch.distributed as dist -from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator from forge.data.datasets import HfIterableDataset +from forge.data.metric_transform import DefaultDatasetMetricTransform from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader @@ -93,7 +93,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file): split="train", # dataset_name not provided - should auto-generate seed=SEED, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=4, ) @@ -131,7 +131,7 @@ def test_default_dataset_name(self, small_dataset_file): dataset_name="my_dataset", weight=custom_weight, seed=SEED, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=4, ) @@ -149,17 +149,16 @@ def test_epoch_boundaries_and_checkpointing( the epoch metric is correct, and checkpointing works as expected. 
""" - # 1. Setup Dataloaders and Aggregators for original and resumed runs - def create_loader_and_aggregator(): + # 1. Setup Dataloaders for original and resumed runs + def create_loader(): dataset = dataset_factory(small_dataset_file, shuffle=False) loader = StatefulDataLoader( dataset, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - aggregator = MetricsAggregator() - return loader, aggregator + return loader - loader1, aggregator1 = create_loader_and_aggregator() - loader2, aggregator2 = create_loader_and_aggregator() + loader1 = create_loader() + loader2 = create_loader() # 2. Calculate steps for the test run total_samples = int(SMALL_DATASET_SIZE * num_epochs) @@ -171,11 +170,9 @@ def create_loader_and_aggregator(): # 3. Generate checkpoint and resume result = generate_ckpt( loader1, - aggregator1, steps_before_checkpoint=steps_before_checkpoint, steps_after_checkpoint=steps_after_checkpoint, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # 4. Verify checkpointing and resumption @@ -184,9 +181,10 @@ def create_loader_and_aggregator(): assert ( orig_post_ids == resumed_ids ), "Resumed batches should be identical for deterministic run" + assert ( - result["final_metrics"] == result["resumed_metrics"] - ), "Final metrics should match" + result["post_checkpoint_metrics"] == result["resumed_metrics"] + ), "Resumed training should produce same metrics as original training" def test_shuffling_behavior(self, dataset_factory, small_dataset_file): """Tests that shuffling changes data order between epochs but preserves the set of samples.""" @@ -253,9 +251,7 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): for sample in first_epoch_samples: first_epoch_metrics.extend(sample["metrics"]) epoch_values = [ - metric.value - for metric in first_epoch_metrics - if metric.metric_name == "epoch" + metric.value for metric in first_epoch_metrics if "num_epochs" in metric.key ] assert all( epoch_value == 0 for epoch_value in epoch_values @@ -268,7 +264,7 @@ def test_epoch_tracking(self, dataset_factory, small_dataset_file): epoch_values = [ metric.value for metric in second_epoch_metrics - if metric.metric_name == "epoch" + if "num_epochs" in metric.key ] assert all( epoch_value == 1 for epoch_value in epoch_values @@ -291,30 +287,20 @@ def test_distributed_epoch_boundary_checkpointing(self): """ rank = dist.get_rank() - # Create shared temp directory (only rank 0 creates it) - if rank == 0: - temp_dir = tempfile.mkdtemp(prefix="epoch_test_") - else: - temp_dir = "" - - # Broadcast temp directory path to all ranks - temp_dir_list = [temp_dir] - dist.broadcast_object_list(temp_dir_list, src=0) - temp_dir = temp_dir_list[0] + # Each rank creates its own local temp dir and files + temp_dir = tempfile.mkdtemp(prefix=f"epoch_test_rank{rank}_") tmp_path = Path(temp_dir) try: medium_dataset_file = tmp_path / "medium_data.json" - # Only rank 0 creates the data file, all ranks read from it - if rank == 0: - create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE) - dist.barrier() # Wait for file creation + # Each rank creates its own file + create_test_json_file(medium_dataset_file, MEDIUM_DATASET_SIZE) # Test multiple epoch boundaries for num_epochs in [0.9, 1.0, 2.5]: - def create_loader_and_aggregator(): + def create_loader(): dataset = HfIterableDataset( path="json", data_files=str(medium_dataset_file), @@ -322,7 +308,7 @@ def create_loader_and_aggregator(): dataset_name="epoch_test", seed=SEED, shuffle_buffer_size=0, # No shuffle for determinism - 
metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, ) loader = StatefulDataLoader( @@ -331,10 +317,10 @@ def create_loader_and_aggregator(): collate_fn=collate_with_metrics, num_workers=0, ) - return loader, MetricsAggregator() + return loader - loader1, aggregator1 = create_loader_and_aggregator() - loader2, aggregator2 = create_loader_and_aggregator() + loader1 = create_loader() + loader2 = create_loader() # Calculate steps to reach desired epoch boundary samples_per_rank = MEDIUM_DATASET_SIZE // dist.get_world_size() @@ -352,11 +338,9 @@ def create_loader_and_aggregator(): result = generate_ckpt( loader1, - aggregator1, - steps_before, - steps_after, + steps_before_checkpoint=steps_before, + steps_after_checkpoint=steps_after, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # Verify deterministic resumption - critical for distributed training @@ -375,10 +359,8 @@ def create_loader_and_aggregator(): num_epochs - 1e-9 ) # -1e-9 so 1.0 epochs -> 0 assert ( - final_metrics["train_epoch_test/num_epochs"] == expected_epoch + final_metrics["dataset/epoch_test/num_epochs"] == expected_epoch ), f"Epoch count incorrect for {num_epochs} epochs test scenario" finally: - # Clean up temp directory (only rank 0) - if rank == 0: - shutil.rmtree(temp_dir) + shutil.rmtree(temp_dir) diff --git a/tests/unit_tests/datasets/test_interleaved.py b/tests/unit_tests/datasets/test_interleaved.py index 0073b905e..f5d55042c 100644 --- a/tests/unit_tests/datasets/test_interleaved.py +++ b/tests/unit_tests/datasets/test_interleaved.py @@ -28,9 +28,9 @@ import torch import torch.distributed as dist - -from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricsAggregator from forge.data.datasets import HfIterableDataset, InterleavedDataset + +from forge.data.metric_transform import DefaultDatasetMetricTransform from torch.testing._internal.common_fsdp import FSDPTest from torchdata.stateful_dataloader import StatefulDataLoader @@ -114,7 +114,7 @@ def _create_dataset( dataset_name=dataset_name, seed=SEED, shuffle_buffer_size=10 if shuffle else 0, - metric_transform=DefaultTrainingMetricTransform(), + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, **kwargs, ) @@ -299,37 +299,47 @@ def test_metrics_aggregation( [child_interleaved, ds3], seed=SEED, dataset_name="parent" ) - aggregator = MetricsAggregator() + # Collect metrics + collected_metrics = [] # Process some samples total_samples = 300 for sample in islice(iter(parent_interleaved), total_samples): - aggregator.update(sample["metrics"]) - - metrics = aggregator.get_metrics_for_logging(prefix="train") - - # Should have metrics from all three datasets, with flat keys - assert "train_ds1/samples_seen" in metrics - assert "train_ds2/samples_seen" in metrics - assert "train_ds3/samples_seen" in metrics + if "metrics" in sample: + collected_metrics.extend(sample["metrics"]) + + # Count metrics by dataset name + ds1_samples_processed = sum( + 1 + for m in collected_metrics + if hasattr(m, "key") and "dataset/ds1/samples_processed" in m.key + ) + ds2_samples_processed = sum( + 1 + for m in collected_metrics + if hasattr(m, "key") and "dataset/ds2/samples_processed" in m.key + ) + ds3_samples_processed = sum( + 1 + for m in collected_metrics + if hasattr(m, "key") and "dataset/ds3/samples_processed" in m.key + ) # All datasets should have contributed samples - assert metrics["train_ds1/samples_seen"] > 0 - assert 
metrics["train_ds2/samples_seen"] > 0 - assert metrics["train_ds3/samples_seen"] > 0 + assert ds1_samples_processed > 0, "ds1 should have contributed samples" + assert ds2_samples_processed > 0, "ds2 should have contributed samples" + assert ds3_samples_processed > 0, "ds3 should have contributed samples" # Total samples should equal what we processed calculated_total_samples = ( - metrics["train_ds1/samples_seen"] - + metrics["train_ds2/samples_seen"] - + metrics["train_ds3/samples_seen"] + ds1_samples_processed + ds2_samples_processed + ds3_samples_processed ) assert calculated_total_samples == total_samples # Test that ratios are approximately correct based on nested weighting - ds1_ratio = metrics["train_ds1/samples_seen"] / total_samples - ds2_ratio = metrics["train_ds2/samples_seen"] / total_samples - ds3_ratio = metrics["train_ds3/samples_seen"] / total_samples + ds1_ratio = ds1_samples_processed / total_samples + ds2_ratio = ds2_samples_processed / total_samples + ds3_ratio = ds3_samples_processed / total_samples # Expected ratios based on nested weighting: # Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0 @@ -377,32 +387,30 @@ def create_interleaved(): loader1 = StatefulDataLoader( interleaved1, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - aggregator1 = MetricsAggregator() # Resumed run interleaved2 = create_interleaved() loader2 = StatefulDataLoader( interleaved2, batch_size=BATCH_SIZE, collate_fn=collate_with_metrics ) - aggregator2 = MetricsAggregator() result = generate_ckpt( loader1, - aggregator1, steps_before_checkpoint=10, steps_after_checkpoint=20, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) + # Verify checkpointing and resumption work correctly + # After loading a checkpoint, training should continue identically orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]] resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]] assert ( orig_post_ids == resumed_ids ), "Resumed batches should be identical for deterministic run" assert ( - result["final_metrics"] == result["resumed_metrics"] - ), "Final metrics should match" + result["post_checkpoint_metrics"] == result["resumed_metrics"] + ), "Resumed training should produce same metrics as original training" # Test sampling log functionality # Check that sampling log contains tuples of (iteration_count, dataset_name) @@ -476,43 +484,36 @@ def test_distributed_interleaved_checkpointing(self): """ rank = dist.get_rank() - # Create shared temp directory (only rank 0 creates it) - if rank == 0: - temp_dir = tempfile.mkdtemp(prefix="interleaved_test_") - else: - temp_dir = None - - # Broadcast temp directory to all ranks - temp_dir_list = [temp_dir] if temp_dir is not None else [""] - dist.broadcast_object_list(temp_dir_list, src=0) - temp_dir = temp_dir_list[0] + # Each rank creates its own local temp dir and files (no broadcast/barrier needed for creation) + temp_dir = tempfile.mkdtemp(prefix=f"interleaved_test_rank{rank}_") tmp_path = Path(temp_dir) try: - - def create_dataset(): - file1 = tmp_path / "ds1.json" - file2 = tmp_path / "ds2.json" - file3 = tmp_path / "ds3.json" - - # Only rank 0 creates the data files - if rank == 0: - create_test_json_file(file1, SMALL_DATASET_SIZE) # IDs 0-22 - create_test_json_file( - file2, MEDIUM_DATASET_SIZE, offset=100 - ) # IDs 100-134 - create_test_json_file( - file3, LARGE_DATASET_SIZE, offset=1000 - ) # IDs 1000-1046 - dist.barrier() # Wait for file creation - + # ============================================ + # SETUP: 
Each rank creates its own test files + # ============================================ + file1 = tmp_path / "ds1.json" + file2 = tmp_path / "ds2.json" + file3 = tmp_path / "ds3.json" + + create_test_json_file(file1, SMALL_DATASET_SIZE, offset=0) + create_test_json_file(file2, MEDIUM_DATASET_SIZE, offset=100) + create_test_json_file(file3, LARGE_DATASET_SIZE, offset=1000) + + # No barrier needed since files are local to each rank + + # ============================================ + # TEST LOGIC: Functions that use the files + # ============================================ + def create_dataset() -> InterleavedDataset: + """Create interleaved dataset from local files.""" ds1 = HfIterableDataset( path="json", data_files=str(file1), split="train", dataset_name="ds1", - shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultTrainingMetricTransform(), + shuffle_buffer_size=0, + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, weight=0.3, ) @@ -521,8 +522,8 @@ def create_dataset(): data_files=str(file2), split="train", dataset_name="ds2", - shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultTrainingMetricTransform(), + shuffle_buffer_size=0, + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, weight=0.7, ) @@ -531,8 +532,8 @@ def create_dataset(): data_files=str(file3), split="train", dataset_name="ds3", - shuffle_buffer_size=0, # No shuffle for determinism - metric_transform=DefaultTrainingMetricTransform(), + shuffle_buffer_size=0, + metric_transform=DefaultDatasetMetricTransform(), num_shards_per_rank=2, weight=1.0, ) @@ -552,19 +553,17 @@ def create_dataloader(dataset): num_workers=0, # Avoid multiprocessing in distributed tests collate_fn=collate_with_metrics, ) - return loader, MetricsAggregator() + return loader # Run checkpointing test with small number of steps - loader1, aggregator1 = create_dataloader(create_dataset()) - loader2, aggregator2 = create_dataloader(create_dataset()) + loader1 = create_dataloader(create_dataset()) + loader2 = create_dataloader(create_dataset()) result = generate_ckpt( loader1, - aggregator1, - 3, - 3, # 3 steps before, 3 steps after checkpoint + steps_before_checkpoint=3, + steps_after_checkpoint=3, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # Verify deterministic resumption @@ -577,8 +576,8 @@ def create_dataloader(dataset): f"This indicates sampling state is not properly preserved." 
) assert ( - result["final_metrics"] == result["resumed_metrics"] - ), "Final metrics don't match resumed metrics - aggregator state issue" + result["post_checkpoint_metrics"] == result["resumed_metrics"] + ), "Resumed training should produce same metrics as original training" # Verify sampling ratio is approximately maintained for nested structure all_ids = [] @@ -621,6 +620,5 @@ def create_dataloader(dataset): ), f"ds3 ratio {ds3_ratio:.2f} should be ~{expected_ds3_ratio}" finally: - # Clean up temp directory (only rank 0) - if rank == 0: - shutil.rmtree(temp_dir) + # Each rank cleans its own temp dir + shutil.rmtree(temp_dir) diff --git a/tests/unit_tests/datasets/test_iterable_utils.py b/tests/unit_tests/datasets/test_iterable_utils.py index cdeced7c7..0c6d26fe3 100644 --- a/tests/unit_tests/datasets/test_iterable_utils.py +++ b/tests/unit_tests/datasets/test_iterable_utils.py @@ -7,92 +7,91 @@ from typing import Any, Optional import torch -from forge.data.dataset_metrics import MetricsAggregator - from torch.utils.data import DataLoader -def collate_with_metrics(batch: list[dict[str, Any]]) -> dict[str, Any]: - """Simple collate that extracts metrics and pads tokens.""" - all_metrics = [] - clean_batch = [] +def collate_with_metrics(batch): + """ + Simple collate function that preserves metrics for validation. + Collects metrics from all samples in the batch and aggregates them. + + Uses a simple collation that doesn't enforce same sizes for lists/tokens. + """ + # Collect metrics from all samples + batch_metrics = [] for sample in batch: if "metrics" in sample: - all_metrics.extend(sample.pop("metrics")) - clean_batch.append(sample) - - if not clean_batch: - return {"metrics": all_metrics} - - # Simple padding for tokens - ids = torch.tensor([item["id"] for item in clean_batch]) - tokens = torch.nn.utils.rnn.pad_sequence( - [torch.tensor(item["tokens"]) for item in clean_batch], - batch_first=True, - padding_value=-1, # Use -1 for padding to distinguish from valid IDs - ) - collated = { - "id": ids, - "tokens": tokens, - } - - # Add text field for non-tensor data - if "text" in clean_batch[0]: - collated["text"] = [item["text"] for item in clean_batch] + batch_metrics.extend(sample.pop("metrics")) + + # Simple collation that handles variable-length sequences + collated = {} + if batch: + for key in batch[0].keys(): + values = [sample[key] for sample in batch] + if key == "tokens" or key == "labels": + # Keep as list of lists for variable length sequences + collated[key] = values + else: + # Use default collation for scalars + collated[key] = torch.utils.data.default_collate(values) + + # Add batch-level metrics key for downstream processing + if batch_metrics: + collated["metrics"] = batch_metrics - collated["metrics"] = all_metrics return collated def generate_ckpt( dataloader: DataLoader, - aggregator: MetricsAggregator, steps_before_checkpoint: int, steps_after_checkpoint: int, resume_dataloader: Optional[DataLoader] = None, - resume_aggregator: Optional[MetricsAggregator] = None, ) -> dict[str, Any]: """ Generates a checkpoint by running through data and saving checkpoint mid-stream. - Optionally, a second dataloader and aggregator can be given to resume from ckpt + Optionally, a second dataloader can be given to resume from checkpoint and run steps_after_checkpoint to match the first one. + Collects and aggregates metrics for test validation purposes. 
+ Args: dataloader (DataLoader): The dataloader to test - aggregator (MetricsAggregator): The metrics aggregator to use steps_before_checkpoint (int): Number of steps to run before saving checkpoint steps_after_checkpoint (int): Number of steps to run after checkpoint resume_dataloader (Optional[DataLoader]): Optional new dataloader to test resuming. If None, returns empty resumed_batches. - resume_aggregator (Optional[MetricsAggregator]): Optional new aggregator to test resuming. - If None, returns empty resumed_metrics. Returns: - dict[str, Any]: Dict with batches/metrics from both pre and post checkpoint runs. + dict[str, Any]: Dict with batches and aggregated metrics for validation. """ iterator = iter(dataloader) - # Collect batches before and after checkpoint + # Collect batches and metrics before and after checkpoint batches = [] + all_metrics = [] # All metrics collected during the run + checkpoint_metrics = [] # Metrics collected only up to checkpoint checkpoint_state = None - metrics_at_checkpoint = {} total_steps = steps_before_checkpoint + steps_after_checkpoint for idx, batch in enumerate(iterator): batches.append(batch) - # Process metrics + # Collect metrics for test validation if "metrics" in batch: - aggregator.update(batch.pop("metrics")) + batch_metrics = batch.pop("metrics") + all_metrics.extend(batch_metrics) + + # If we haven't reached checkpoint yet, also add to checkpoint metrics + if idx < steps_before_checkpoint: + checkpoint_metrics.extend(batch_metrics) # Save checkpoint state after steps_before_checkpoint if idx == steps_before_checkpoint - 1: # -1 because idx is 0-based checkpoint_state = { "loader": dataloader.state_dict(), - "aggregator": aggregator.state_dict(), } - metrics_at_checkpoint = aggregator.get_metrics_for_logging(prefix="train") # Stop after total steps if idx == total_steps - 1: @@ -102,43 +101,56 @@ def generate_ckpt( pre_checkpoint_batches = batches[:steps_before_checkpoint] post_checkpoint_batches = batches[steps_before_checkpoint:] - # Resume with new instances if provided + # Compute metrics for post-checkpoint batches only + post_checkpoint_metrics = all_metrics[len(checkpoint_metrics) :] + + # Resume with new instance if provided resumed_batches = [] - resumed_metrics = {} - - if ( - resume_dataloader is not None - and resume_aggregator is not None - and checkpoint_state is not None - ): - # Test resuming with new instances + resumed_metrics = [] + + if resume_dataloader is not None and checkpoint_state is not None: + # Test resuming with new instance resume_dataloader.load_state_dict(checkpoint_state["loader"]) - resume_aggregator.load_state_dict(checkpoint_state["aggregator"]) resume_iterator = iter(resume_dataloader) # Collect only the post-checkpoint batches when resuming for idx, batch in enumerate(resume_iterator): resumed_batches.append(batch) - # Process metrics + # Collect metrics from resumed batches if "metrics" in batch: - resume_aggregator.update(batch.pop("metrics")) + batch_metrics = batch.pop("metrics") + resumed_metrics.extend(batch_metrics) # Stop after steps_after_checkpoint if idx == steps_after_checkpoint - 1: break - resumed_metrics = resume_aggregator.get_metrics_for_logging(prefix="train") - return { # Original run "pre_checkpoint_batches": pre_checkpoint_batches, "post_checkpoint_batches": post_checkpoint_batches, - "metrics_at_checkpoint": metrics_at_checkpoint, - "final_metrics": aggregator.get_metrics_for_logging(prefix="train"), + "metrics_at_checkpoint": aggregate_metrics(checkpoint_metrics), + 
"post_checkpoint_metrics": aggregate_metrics(post_checkpoint_metrics), + "final_metrics": aggregate_metrics(all_metrics), # Resumed run "resumed_batches": resumed_batches, - "resumed_metrics": resumed_metrics, + "resumed_metrics": aggregate_metrics(resumed_metrics), # Internal state for loading - only if someone needs to manually load "_checkpoint_state": checkpoint_state, } + + +def aggregate_metrics(metrics_list: list) -> dict[str, Any]: + if not metrics_list: + return {} + + accumulators = {} + + for metric in metrics_list: + key = metric.key + if key not in accumulators: + accumulators[key] = metric.reduction.accumulator_class(metric.reduction) + accumulators[key].append(metric.value) + + return {key: acc.get_value() for key, acc in accumulators.items()} diff --git a/tests/unit_tests/datasets/test_packed.py b/tests/unit_tests/datasets/test_packed.py index 352fbf703..56cd5ff02 100644 --- a/tests/unit_tests/datasets/test_packed.py +++ b/tests/unit_tests/datasets/test_packed.py @@ -14,7 +14,6 @@ import torch from forge.data.collate import collate_packed -from forge.data.dataset_metrics import MetricsAggregator from forge.data.datasets import HfIterableDataset from forge.data.datasets.packed import ( _SUPPORTS_FLEX_ATTENTION, @@ -914,7 +913,7 @@ def test_checkpoint_and_resume(self, dataset_factory): batch_size = 1 # Setup dataset factory - def create_loader_and_aggregator(): + def create_loader(): dataset = dataset_factory(samples) packer = TextPacker(padding_idx=999, ignore_idx=-100) packed_dataset = PackedDataset( @@ -931,11 +930,10 @@ def create_loader_and_aggregator(): loader = StatefulDataLoader( packed_dataset, batch_size=batch_size, collate_fn=collate_fn ) - aggregator = MetricsAggregator() - return loader, aggregator + return loader - loader1, aggregator1 = create_loader_and_aggregator() - loader2, aggregator2 = create_loader_and_aggregator() + loader1 = create_loader() + loader2 = create_loader() steps_before_checkpoint = 2 steps_after_checkpoint = 2 @@ -943,11 +941,9 @@ def create_loader_and_aggregator(): # Generate checkpoint and resume result = generate_ckpt( loader1, - aggregator1, steps_before_checkpoint=steps_before_checkpoint, steps_after_checkpoint=steps_after_checkpoint, resume_dataloader=loader2, - resume_aggregator=aggregator2, ) # Verify that checkpointing and resumption work diff --git a/tests/unit_tests/observability/test_perf_tracker.py b/tests/unit_tests/observability/test_perf_tracker.py index 6faef3956..1beeb647a 100644 --- a/tests/unit_tests/observability/test_perf_tracker.py +++ b/tests/unit_tests/observability/test_perf_tracker.py @@ -55,7 +55,7 @@ def assert_metrics_dict_matches(calls, expected_metrics): assert metric_name in actual_metrics, f"Missing metric: {metric_name}" actual_val = actual_metrics[metric_name] assert actual_val == pytest.approx( - expected_val, rel=0.1 # 10% relative tolerance for timing tests + expected_val, rel=0.2 # 20% relative tolerance for timing tests ), f"Expected {metric_name}={expected_val}, got {actual_val}"