@@ -305,36 +305,30 @@ def __next__(self) -> dict:
305305 return current_batch
306306
307307
308- def extract_epoch_from_batch (batch : dict | list ) -> int :
309- """Extract epoch number from batch metrics.
308+ def extract_epoch_from_batch (batch : dict ) -> int :
309+ """Extract epoch number from batch metrics. Useful to detect epoch changes during validation,
310+ where we want to run exactly one epoch.
310311
311- Assumes datasets inherit from InfiniteTuneIterableDataset which always
312- adds num_epochs metric. Raises clear error if assumption is violated .
312+ Assumes the dataset adds "num_epochs" Metric to teh sample, where one epoch is incremented on dataset exhaustion.
313+ For an example, check forge.src.data.datasets.HfIterableDataset .
313314
314315 Args:
315- batch: Batch dictionary with 'metrics' field OR list of sample dicts
316+ batch (dict) : Batch dictionary with 'metrics' field
316317
317318 Returns:
318- Epoch number from metrics
319+ int: Max epoch number from metrics
319320
320321 Raises:
321- ValueError: If metrics missing or no num_epochs found
322+ ValueError: If metrics key is missing or not metric ` num_epochs` found
322323 """
323- # Handle list of samples (uncollated batches)
324- if isinstance (batch , list ):
325- if not batch :
326- raise ValueError ("Empty batch provided" )
327- batch = batch [0 ] # Extract first sample
328-
329324 if "metrics" not in batch :
330325 raise ValueError (
331- "Batch missing 'metrics' field. Ensure dataset inherits from "
332- "InfiniteTuneIterableDataset which adds this automatically."
326+ "Batch missing 'metrics' field. Cannot extract epoch from batch."
333327 )
334328
335- for metric in batch ["metrics" ]:
336- if "num_epochs" in metric . key :
337- return int ( metric . value )
329+ epochs = [ metric . value for metric in batch ["metrics" ] if metric . key == "num_epochs" ]
330+ if epochs :
331+ return max ( epochs )
338332
339333 raise ValueError (
340334 f"No 'num_epochs' metric found in batch. Got metrics: "
0 commit comments