Skip to content

Commit 2bd3b35

Browse files
author
Felipe Mello
committed
docs and naming
1 parent 57877da commit 2bd3b35

File tree

8 files changed

+147
-129
lines changed

8 files changed

+147
-129
lines changed

src/forge/data/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
from .collate import collate_packed
7+
from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
78
from .utils import CROSS_ENTROPY_IGNORE_IDX
89

9-
__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"]
10+
__all__ = [
11+
"collate_packed",
12+
"CROSS_ENTROPY_IGNORE_IDX",
13+
"MetricTransform",
14+
"DefaultDatasetMetricTransform",
15+
]

src/forge/data/dataset_metrics/__init__.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

src/forge/data/dataset_metrics/metric_transform.py

Lines changed: 0 additions & 90 deletions
This file was deleted.

src/forge/data/datasets/hf_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from datasets import load_dataset
1313
from datasets.distributed import split_dataset_by_node
1414

15-
from forge.data.dataset_metrics import DefaultTrainingMetricTransform, MetricTransform
15+
from forge.data.metric_transform import DefaultDatasetMetricTransform, MetricTransform
1616
from forge.interfaces import Transform
1717
from forge.observability.metrics import Metric, Reduce
1818

@@ -82,7 +82,7 @@ def __init__(
8282
self._weight = weight if weight is not None else 1.0
8383

8484
# Create default transform if not provided
85-
self._metric_transform = metric_transform or DefaultTrainingMetricTransform()
85+
self._metric_transform = metric_transform or DefaultDatasetMetricTransform()
8686

8787
# Auto-generate dataset name if not provided
8888
if dataset_name is None:

src/forge/data/datasets/sft_dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010

1111
from forge.data import CROSS_ENTROPY_IGNORE_IDX
12-
from forge.data.dataset_metrics import DefaultTrainingMetricTransform
12+
from forge.data.metric_transform import DefaultDatasetMetricTransform
1313
from forge.data.utils import mask_messages, TuneMessage
1414
from forge.interfaces import Transform
1515

@@ -200,7 +200,7 @@ def sft_iterable_dataset(
200200
message_transform=message_transform,
201201
model_transform=model_transform,
202202
output_transform=output_transform,
203-
metric_transform=DefaultTrainingMetricTransform(),
203+
metric_transform=DefaultDatasetMetricTransform(),
204204
shuffle_buffer_size=shuffle_buffer_size,
205205
weight=weight,
206206
seed=seed,

src/forge/data/metric_transform.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from typing import Any
8+
9+
from forge.interfaces import Transform
10+
from forge.observability.metrics import Metric, Reduce
11+
12+
13+
class MetricTransform(Transform):
    """Base class for transforms that attach observability metrics to dataset samples.

    Subclasses override ``__call__`` to append
    ``forge.observability.metrics.Metric`` objects to each sample flowing
    through a data pipeline; collected metrics surface in ``batch["metrics"]``.

    Attributes:
        source (str, optional): Name used to namespace metric keys, typically
            the dataset name. Distinguishes metrics coming from different data
            sources; set via :meth:`set_source`.

    Example:
        >>> transform = SomeMetricTransform()
        >>> transform.set_source("training_data")
        >>> processed_sample = transform(sample)
        >>> # Metrics are automatically added to sample["metrics"]
    """

    def __init__(self):
        # No source until a dataset registers itself via set_source().
        self.source = None

    def set_source(self, source: str):
        """Record the source name (typically the dataset name) used in metric keys."""
        self.source = source

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Base implementation: pass the sample through unchanged."""
        return sample
46+
47+
48+
class DefaultDatasetMetricTransform(MetricTransform):
    """Collects basic dataset processing metrics during pipeline execution.

    Metrics collected (all keyed ``dataset/{source}/<name>``):
        - samples_processed: total samples that passed through this transform (SUM)
        - tokens_processed: total tokens processed across all samples (SUM)
        - mean_seq_len: average sequence length across samples (MEAN)
        - max_seq_len: maximum sequence length observed (MAX)
        - min_seq_len: minimum sequence length observed (MIN)

    Note: token-related metrics are only emitted when the sample contains a
    ``tokens`` field; sequence length is the token count of that sample.

    Example:
        >>> collector = DefaultDatasetMetricTransform()
        >>> collector.set_source("training_data")
        >>> sample = {"tokens": ["hello", "world"]}
        >>> processed_sample = collector(sample)
        >>> # Metrics are automatically added to processed_sample["metrics"]
    """

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        """Append processing metrics to ``sample["metrics"]`` and return the sample."""
        metrics = sample.setdefault("metrics", [])
        # Fall back to a generic label when no dataset registered a source name.
        prefix = f"dataset/{self.source or 'dataset'}"

        # Every sample contributes one unit toward the running sample count.
        metrics.append(
            Metric(key=f"{prefix}/samples_processed", value=1, reduction=Reduce.SUM)
        )

        # Token-derived stats only make sense when the sample carries tokens.
        if "tokens" in sample:
            seq_len = len(sample["tokens"])
            for suffix, reduction in (
                ("tokens_processed", Reduce.SUM),
                ("mean_seq_len", Reduce.MEAN),
                ("max_seq_len", Reduce.MAX),
                ("min_seq_len", Reduce.MIN),
            ):
                metrics.append(
                    Metric(
                        key=f"{prefix}/{suffix}",
                        value=seq_len,
                        reduction=reduction,
                    )
                )

        return sample

tests/unit_tests/datasets/test_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,9 @@
2525

2626
import pytest
2727
import torch.distributed as dist
28-
from forge.data.dataset_metrics import DefaultTrainingMetricTransform
2928

3029
from forge.data.datasets import HfIterableDataset
30+
from forge.data.metric_transform import DefaultDatasetMetricTransform
3131
from torch.testing._internal.common_fsdp import FSDPTest
3232

3333
from torchdata.stateful_dataloader import StatefulDataLoader
@@ -93,7 +93,7 @@ def _create_dataset(
9393
dataset_name=dataset_name,
9494
seed=SEED,
9595
shuffle_buffer_size=10 if shuffle else 0,
96-
metric_transform=DefaultTrainingMetricTransform(),
96+
metric_transform=DefaultDatasetMetricTransform(),
9797
num_shards_per_rank=2,
9898
**kwargs,
9999
)

tests/unit_tests/datasets/test_interleaved.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828

2929
import torch
3030
import torch.distributed as dist
31-
32-
from forge.data.dataset_metrics import DefaultTrainingMetricTransform
3331
from forge.data.datasets import HfIterableDataset, InterleavedDataset
32+
33+
from forge.data.metric_transform import DefaultDatasetMetricTransform
3434
from torch.testing._internal.common_fsdp import FSDPTest
3535
from torchdata.stateful_dataloader import StatefulDataLoader
3636

@@ -114,7 +114,7 @@ def _create_dataset(
114114
dataset_name=dataset_name,
115115
seed=SEED,
116116
shuffle_buffer_size=10 if shuffle else 0,
117-
metric_transform=DefaultTrainingMetricTransform(),
117+
metric_transform=DefaultDatasetMetricTransform(),
118118
num_shards_per_rank=2,
119119
**kwargs,
120120
)
@@ -308,38 +308,38 @@ def test_metrics_aggregation(
308308
if "metrics" in sample:
309309
collected_metrics.extend(sample["metrics"])
310310

311-
# Count metrics by dataset name
312-
ds1_samples_seen = sum(
311+
# Count metrics by dataset name (using new metric key)
312+
ds1_samples_processed = sum(
313313
1
314314
for m in collected_metrics
315-
if hasattr(m, "key") and "dataset/ds1/samples_seen" in m.key
315+
if hasattr(m, "key") and "dataset/ds1/samples_processed" in m.key
316316
)
317-
ds2_samples_seen = sum(
317+
ds2_samples_processed = sum(
318318
1
319319
for m in collected_metrics
320-
if hasattr(m, "key") and "dataset/ds2/samples_seen" in m.key
320+
if hasattr(m, "key") and "dataset/ds2/samples_processed" in m.key
321321
)
322-
ds3_samples_seen = sum(
322+
ds3_samples_processed = sum(
323323
1
324324
for m in collected_metrics
325-
if hasattr(m, "key") and "dataset/ds3/samples_seen" in m.key
325+
if hasattr(m, "key") and "dataset/ds3/samples_processed" in m.key
326326
)
327327

328328
# All datasets should have contributed samples
329-
assert ds1_samples_seen > 0, "ds1 should have contributed samples"
330-
assert ds2_samples_seen > 0, "ds2 should have contributed samples"
331-
assert ds3_samples_seen > 0, "ds3 should have contributed samples"
329+
assert ds1_samples_processed > 0, "ds1 should have contributed samples"
330+
assert ds2_samples_processed > 0, "ds2 should have contributed samples"
331+
assert ds3_samples_processed > 0, "ds3 should have contributed samples"
332332

333333
# Total samples should equal what we processed
334334
calculated_total_samples = (
335-
ds1_samples_seen + ds2_samples_seen + ds3_samples_seen
335+
ds1_samples_processed + ds2_samples_processed + ds3_samples_processed
336336
)
337337
assert calculated_total_samples == total_samples
338338

339339
# Test that ratios are approximately correct based on nested weighting
340-
ds1_ratio = ds1_samples_seen / total_samples
341-
ds2_ratio = ds2_samples_seen / total_samples
342-
ds3_ratio = ds3_samples_seen / total_samples
340+
ds1_ratio = ds1_samples_processed / total_samples
341+
ds2_ratio = ds2_samples_processed / total_samples
342+
ds3_ratio = ds3_samples_processed / total_samples
343343

344344
# Expected ratios based on nested weighting:
345345
# Inner weights: ds1=0.2, ds2=0.8 -> inner total=1.0
@@ -518,7 +518,7 @@ def create_dataset():
518518
split="train",
519519
dataset_name="ds1",
520520
shuffle_buffer_size=0, # No shuffle for determinism
521-
metric_transform=DefaultTrainingMetricTransform(),
521+
metric_transform=DefaultDatasetMetricTransform(),
522522
num_shards_per_rank=2,
523523
weight=0.3,
524524
)
@@ -528,7 +528,7 @@ def create_dataset():
528528
split="train",
529529
dataset_name="ds2",
530530
shuffle_buffer_size=0, # No shuffle for determinism
531-
metric_transform=DefaultTrainingMetricTransform(),
531+
metric_transform=DefaultDatasetMetricTransform(),
532532
num_shards_per_rank=2,
533533
weight=0.7,
534534
)
@@ -538,7 +538,7 @@ def create_dataset():
538538
split="train",
539539
dataset_name="ds3",
540540
shuffle_buffer_size=0, # No shuffle for determinism
541-
metric_transform=DefaultTrainingMetricTransform(),
541+
metric_transform=DefaultDatasetMetricTransform(),
542542
num_shards_per_rank=2,
543543
weight=1.0,
544544
)

0 commit comments

Comments
 (0)