Commit 7838dc4

Author: Felipe Mello (committed)
fix tests
1 parent 710703e commit 7838dc4

File tree: 5 files changed, +29 -21 lines


apps/sft/main.py

Lines changed: 0 additions & 1 deletion
@@ -78,7 +78,6 @@ def __init__(self, config: DictConfig):
 
         self.current_step = 0
         self.num_training_steps = job_config.training.steps
-        self.metric_logger = None  # TODO: fix this
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())

src/forge/data/metric_transform.py

Lines changed: 2 additions & 3 deletions
@@ -6,11 +6,10 @@
 
 from typing import Any
 
-from forge.interfaces import Transform
 from forge.observability.metrics import Metric, Reduce
 
 
-class MetricTransform(Transform):
+class MetricTransform:
     """
     Base class for transforms that collect observability metrics from dataset samples.
 
@@ -71,7 +70,7 @@ def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
         if "metrics" not in sample:
             sample["metrics"] = []
 
-        source_name = self.source or "dataset"
+        source_name = self.source or "unnamed_ds"
 
         # Add samples_processed metric
         sample["metrics"].append(
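
This change drops the Transform inheritance and renames the default source label for unnamed datasets from "dataset" to "unnamed_ds". A minimal, self-contained sketch of that fallback; the SampleCounter class, the Metric dataclass, and the metric key format below are illustrative assumptions, not the forge API:

from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Metric:
    # Illustrative stand-in for forge.observability.metrics.Metric
    key: str
    value: Any


class SampleCounter:
    """Toy transform that tags a per-sample metric with a dataset source name."""

    def __init__(self, source: Optional[str] = None):
        self.source = source

    def __call__(self, sample: dict[str, Any]) -> dict[str, Any]:
        sample.setdefault("metrics", [])
        # Same fallback as the patched MetricTransform: datasets without an
        # explicit source report under "unnamed_ds".
        source_name = self.source or "unnamed_ds"
        sample["metrics"].append(Metric(key=f"{source_name}/samples_processed", value=1))
        return sample


print(SampleCounter()({"text": "hi"})["metrics"][0].key)          # unnamed_ds/samples_processed
print(SampleCounter("alpaca")({"text": "hi"})["metrics"][0].key)  # alpaca/samples_processed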

tests/unit_tests/datasets/test_hf.py

Lines changed: 3 additions & 2 deletions
@@ -181,9 +181,10 @@ def create_loader():
         assert (
             orig_post_ids == resumed_ids
         ), "Resumed batches should be identical for deterministic run"
+
         assert (
-            result["final_metrics"] == result["resumed_metrics"]
-        ), "Final metrics should match"
+            result["post_checkpoint_metrics"] == result["resumed_metrics"]
+        ), "Resumed training should produce same metrics as original training"
 
     def test_shuffling_behavior(self, dataset_factory, small_dataset_file):
         """Tests that shuffling changes data order between epochs but preserves the set of samples."""

tests/unit_tests/datasets/test_interleaved.py

Lines changed: 6 additions & 4 deletions
@@ -401,14 +401,16 @@ def create_interleaved():
             resume_dataloader=loader2,
         )
 
+        # Verify checkpointing and resumption work correctly
+        # After loading a checkpoint, training should continue identically
         orig_post_ids = [b["id"].tolist() for b in result["post_checkpoint_batches"]]
         resumed_ids = [b["id"].tolist() for b in result["resumed_batches"]]
         assert (
             orig_post_ids == resumed_ids
         ), "Resumed batches should be identical for deterministic run"
         assert (
-            result["final_metrics"] == result["resumed_metrics"]
-        ), "Final metrics should match"
+            result["post_checkpoint_metrics"] == result["resumed_metrics"]
+        ), "Resumed training should produce same metrics as original training"
 
         # Test sampling log functionality
         # Check that sampling log contains tuples of (iteration_count, dataset_name)
@@ -581,8 +583,8 @@ def create_dataloader(dataset):
             f"This indicates sampling state is not properly preserved."
         )
         assert (
-            result["final_metrics"] == result["resumed_metrics"]
-        ), "Final metrics don't match resumed metrics - aggregator state issue"
+            result["post_checkpoint_metrics"] == result["resumed_metrics"]
+        ), "Resumed training should produce same metrics as original training"
 
         # Verify sampling ratio is approximately maintained for nested structure
         all_ids = []
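
The updated assertions in test_hf.py and test_interleaved.py follow from switching to real aggregation: a run resumed from a checkpoint only processes the post-checkpoint batches, so its aggregated metrics should match the post-checkpoint slice of the original run rather than the full-run totals. A hedged illustration with made-up SUM-reduced counts:

# Hypothetical numbers: 10 samples seen before the checkpoint, 5 after.
pre_checkpoint = [1] * 10
post_checkpoint = [1] * 5

final_metrics = sum(pre_checkpoint + post_checkpoint)  # 15: whole original run
post_checkpoint_metrics = sum(post_checkpoint)         # 5: only the post-checkpoint slice
resumed_metrics = sum(post_checkpoint)                 # 5: resumed run replays only that slice

assert resumed_metrics == post_checkpoint_metrics  # the new comparison holds
assert resumed_metrics != final_metrics            # the old comparison would now fail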

tests/unit_tests/datasets/test_iterable_utils.py

Lines changed: 18 additions & 11 deletions
@@ -101,6 +101,9 @@ def generate_ckpt(
     pre_checkpoint_batches = batches[:steps_before_checkpoint]
     post_checkpoint_batches = batches[steps_before_checkpoint:]
 
+    # Compute metrics for post-checkpoint batches only
+    post_checkpoint_metrics = all_metrics[len(checkpoint_metrics) :]
+
     # Resume with new instance if provided
     resumed_batches = []
     resumed_metrics = []
@@ -127,24 +130,28 @@ def generate_ckpt(
         # Original run
         "pre_checkpoint_batches": pre_checkpoint_batches,
         "post_checkpoint_batches": post_checkpoint_batches,
-        "metrics_at_checkpoint": keep_last_metric(checkpoint_metrics),
-        "final_metrics": keep_last_metric(all_metrics),
+        "metrics_at_checkpoint": aggregate_metrics(checkpoint_metrics),
+        "post_checkpoint_metrics": aggregate_metrics(post_checkpoint_metrics),
+        "final_metrics": aggregate_metrics(all_metrics),
         # Resumed run
         "resumed_batches": resumed_batches,
-        "resumed_metrics": keep_last_metric(resumed_metrics),
+        "resumed_metrics": aggregate_metrics(resumed_metrics),
         # Internal state for loading - only if someone needs to manually load
         "_checkpoint_state": checkpoint_state,
     }
 
 
-def keep_last_metric(metrics_list: list) -> dict[str, Any]:
-    result = {}
+def aggregate_metrics(metrics_list: list) -> dict[str, Any]:
+    """Aggregate metrics according to their reduction types (SUM, MEAN, MAX, MIN, STD)."""
+    if not metrics_list:
+        return {}
+
+    accumulators = {}
+
     for metric in metrics_list:
-        # Expect observability.Metric objects only
         key = metric.key
-        value = metric.value
-
-        # For test purposes, just keep the last value of each metric
-        result[key] = value
+        if key not in accumulators:
+            accumulators[key] = metric.reduction.accumulator_class(metric.reduction)
+        accumulators[key].append(metric.value)
 
-    return result
+    return {key: acc.get_value() for key, acc in accumulators.items()}
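
The new aggregate_metrics helper delegates to each metric's reduction via reduction.accumulator_class, which is not shown in this diff. The following self-contained sketch uses simplified, assumed Reduce and Accumulator stand-ins (and a hypothetical seq_len metric) to illustrate the same group-by-key, accumulate, then finalize pattern:

from dataclasses import dataclass, field
from enum import Enum
from typing import Any


class Reduce(Enum):
    # Assumed subset of the reduction types named in the docstring above
    SUM = "sum"
    MEAN = "mean"
    MAX = "max"
    MIN = "min"


@dataclass
class Accumulator:
    """Collects raw values and reduces them on demand."""
    reduction: Reduce
    values: list = field(default_factory=list)

    def append(self, value: Any) -> None:
        self.values.append(value)

    def get_value(self) -> Any:
        if self.reduction is Reduce.SUM:
            return sum(self.values)
        if self.reduction is Reduce.MEAN:
            return sum(self.values) / len(self.values)
        if self.reduction is Reduce.MAX:
            return max(self.values)
        return min(self.values)


@dataclass
class Metric:
    # Stand-in for forge.observability.metrics.Metric
    key: str
    value: Any
    reduction: Reduce


def aggregate_metrics(metrics_list: list) -> dict[str, Any]:
    """Group metrics by key and reduce each group according to its reduction type."""
    accumulators: dict[str, Accumulator] = {}
    for metric in metrics_list:
        if metric.key not in accumulators:
            accumulators[metric.key] = Accumulator(metric.reduction)
        accumulators[metric.key].append(metric.value)
    return {key: acc.get_value() for key, acc in accumulators.items()}


metrics = [
    Metric("samples_processed", 1, Reduce.SUM),
    Metric("samples_processed", 1, Reduce.SUM),
    Metric("seq_len", 128, Reduce.MEAN),
    Metric("seq_len", 256, Reduce.MEAN),
]
print(aggregate_metrics(metrics))  # {'samples_processed': 2, 'seq_len': 192.0}

Under the old keep_last_metric helper, samples_processed would have come out as 1 and seq_len as 256, which is why the resumed-run comparisons in the tests above had to change.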
