Skip to content

Commit d9ea30e

Browse files
author
Felipe Mello
committed
improve docstring
1 parent 0e4bdc3 commit d9ea30e

File tree

1 file changed

+12
-18
lines changed

1 file changed

+12
-18
lines changed

src/forge/data/utils.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -305,36 +305,30 @@ def __next__(self) -> dict:
305305
return current_batch
306306

307307

308-
def extract_epoch_from_batch(batch: dict | list) -> int:
309-
"""Extract epoch number from batch metrics.
308+
def extract_epoch_from_batch(batch: dict) -> int:
309+
"""Extract epoch number from batch metrics. Useful to detect epoch changes during validation,
310+
where we want to run exactly one epoch.
310311
311-
Assumes datasets inherit from InfiniteTuneIterableDataset which always
312-
adds num_epochs metric. Raises clear error if assumption is violated.
312+
Assumes the dataset adds "num_epochs" Metric to the sample, where one epoch is incremented on dataset exhaustion.
313+
For an example, check forge.src.data.datasets.HfIterableDataset.
313314
314315
Args:
315-
batch: Batch dictionary with 'metrics' field OR list of sample dicts
316+
batch (dict): Batch dictionary with 'metrics' field
316317
317318
Returns:
318-
Epoch number from metrics
319+
int: Max epoch number from metrics
319320
320321
Raises:
321-
ValueError: If metrics missing or no num_epochs found
322+
ValueError: If metrics key is missing or no `num_epochs` metric is found
322323
"""
323-
# Handle list of samples (uncollated batches)
324-
if isinstance(batch, list):
325-
if not batch:
326-
raise ValueError("Empty batch provided")
327-
batch = batch[0] # Extract first sample
328-
329324
if "metrics" not in batch:
330325
raise ValueError(
331-
"Batch missing 'metrics' field. Ensure dataset inherits from "
332-
"InfiniteTuneIterableDataset which adds this automatically."
326+
"Batch missing 'metrics' field. Cannot extract epoch from batch."
333327
)
334328

335-
for metric in batch["metrics"]:
336-
if "num_epochs" in metric.key:
337-
return int(metric.value)
329+
epochs = [metric.value for metric in batch["metrics"] if metric.key == "num_epochs"]
330+
if epochs:
331+
return max(epochs)
338332

339333
raise ValueError(
340334
f"No 'num_epochs' metric found in batch. Got metrics: "

0 commit comments

Comments
 (0)