Add Grain iterator checkpoint/resume and fix num_batches #22490
MarcosAsh wants to merge 1 commit into keras-team:master
Conversation
Summary of Changes: Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request enhances the fault tolerance of training pipelines by enabling deterministic checkpoint and resume for Grain datasets. Previously, interruptions would cause data pipelines to restart from the beginning, leading to data re-exposure and skipped data. The changes introduce mechanisms to save and restore the exact state of data iterators, ensuring seamless continuation of training. Additionally, num_batches now reports the actual batch count for finite MapDatasets.
Enable deterministic mid-epoch resume for Grain datasets by saving and restoring the DatasetIterator state through BackupAndRestore. Also fix num_batches to return the actual count for finite MapDatasets so progress bars work correctly.
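The checkpoint/resume round trip described above can be illustrated with a minimal stand-in iterator (a sketch: RangeIterator is hypothetical, but it mirrors the get_state()/set_state() contract of Grain's DatasetIterator and the tiny {"next_index": ...} state mentioned in this PR):

```python
import json


class RangeIterator:
    """Toy iterator mirroring the DatasetIterator state contract."""

    def __init__(self, n):
        self._n = n
        self._next_index = 0

    def __next__(self):
        if self._next_index >= self._n:
            raise StopIteration
        value = self._next_index
        self._next_index += 1
        return value

    def get_state(self):
        # The state is tiny: just the position to resume from.
        return {"next_index": self._next_index}

    def set_state(self, state):
        self._next_index = state["next_index"]


it = RangeIterator(10)
for _ in range(5):
    next(it)

# Persist the state, e.g. into training_metadata.json.
saved = json.dumps(it.get_state())

# On resume, a fresh iterator picks up exactly where we left off.
fresh = RangeIterator(10)
fresh.set_state(json.loads(saved))
print(next(fresh))  # → 5
```

Because the state round-trips through JSON, it can be stored alongside the other metadata BackupAndRestore already writes.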
Force-pushed from 68ab6e9 to 0ff26a8
Code Review
This pull request introduces data-pipeline state checkpointing and restoration, primarily for Grain datasets. It exposes the _epoch_iterator in the JAX, TensorFlow, and PyTorch trainers, allowing the BackupAndRestore callback to save and restore the iterator's state. New get_iterator_state and set_iterator_state methods are added to the DataAdapter and EpochIterator classes, with GrainDatasetAdapter implementing them for MapDataset and IterDataset via a new _TrackableIterable wrapper. Additionally, the PyTorch LSTM backend is refactored to use torch._VF.lstm directly, fixing previous cuDNN issues related to gradients and masking, which is reflected in the updated lstm_test.py.
/gemini review
Code Review
This pull request introduces robust fault tolerance for data pipelines, particularly for Grain datasets, by enabling the saving and restoring of data iterator states. The changes involve exposing the _epoch_iterator in the JAX, TensorFlow, and PyTorch trainers, and implementing get_iterator_state and set_iterator_state methods in the DataAdapter and EpochIterator classes. The GrainDatasetAdapter is updated to manage and persist the iterator state for grain.MapDataset and grain.IterDataset, including a new _TrackableIterable wrapper. The BackupAndRestore callback now utilizes this functionality to save and restore the data pipeline state. A review comment suggests adding a warning when an iterator state cannot be restored due to a missing set_state method, to provide better user feedback and adhere to API design guidelines regarding detailed error messages and anticipating common mistakes.
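The adapter-level API the review describes might fit together roughly as follows (a simplified sketch, not the PR's actual code: method names follow the description above, attribute names are assumptions):

```python
class DataAdapter:
    # No-op defaults keep existing adapters backward compatible:
    # adapters that cannot checkpoint simply report no state.
    def get_iterator_state(self):
        return None

    def set_iterator_state(self, state):
        pass


class GrainDatasetAdapter(DataAdapter):
    def __init__(self):
        self._current_iterator = None
        self._pending_iterator_state = None

    def get_iterator_state(self):
        # Ask the live Grain iterator for its position, if it has one.
        it = self._current_iterator
        if it is not None and hasattr(it, "get_state"):
            return it.get_state()
        return None

    def set_iterator_state(self, state):
        # Stored now, applied when the next iterator is created.
        self._pending_iterator_state = state
```

EpochIterator would then delegate its own state calls to whichever adapter it holds, so non-Grain adapters silently fall back to the no-op base behavior.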
    if self._adapter._pending_iterator_state is not None:
        if hasattr(it, "set_state"):
            it.set_state(self._adapter._pending_iterator_state)
        self._adapter._pending_iterator_state = None
If a pending iterator state exists but the new iterator does not have a set_state method, the state is silently discarded. This could lead to unexpected behavior where the data pipeline restarts from the beginning of the epoch without any warning to the user.
Consider adding a warning in the else case to inform the user that the iterator state could not be restored. This would improve debuggability for users who might inadvertently change their data pipeline in a way that makes it non-resumable.
Suggested change:

    if self._adapter._pending_iterator_state is not None:
        if hasattr(it, "set_state"):
            it.set_state(self._adapter._pending_iterator_state)
        else:
            import warnings

            warnings.warn(
                "Attempted to restore a data iterator state, but the "
                "iterator does not support it. The data pipeline will "
                "restart from the beginning of the epoch.",
                stacklevel=2,
            )
        self._adapter._pending_iterator_state = None
References
- The Keras API design guidelines state to 'Provide detailed feedback messages upon user error' (line 140) and 'anticipate common mistakes' (line 139). Silently failing to restore an iterator's state is a potential source of confusion for users, and a warning would provide helpful feedback.
Codecov Report
❌ Patch coverage report, with additional details and impacted files:
@@ Coverage Diff @@
## master #22490 +/- ##
==========================================
- Coverage 83.17% 83.16% -0.01%
==========================================
Files 596 596
Lines 67610 67670 +60
Branches 10531 10543 +12
==========================================
+ Hits 56234 56280 +46
- Misses 8655 8664 +9
- Partials 2721 2726 +5
Flags with carried forward coverage won't be shown.
Summary
Right now, if training with a Grain dataset gets interrupted and resumed via BackupAndRestore, the data pipeline restarts from the beginning: model weights are restored but the iterator position is lost. This means the model sees early data twice and skips later data entirely.

This PR fixes that by wiring Grain's built-in DatasetIterator.get_state()/set_state() through the training stack, so BackupAndRestore can save and restore the exact iterator position. The state is tiny (just {"next_index": 5}) and gets written into the existing training_metadata.json.

Also fixes num_batches for finite MapDatasets: it was hardcoded to None, which meant progress bars never showed a total. Now it returns the actual count via len().
Changes
- DataAdapter base class gets optional get_iterator_state() / set_iterator_state() (no-op defaults, backward compatible)
- GrainDatasetAdapter tracks the live iterator via a _TrackableIterable wrapper and implements the state methods
- EpochIterator delegates state calls to the adapter
- epoch_iterator is exposed on the model so callbacks can reach it
- BackupAndRestore saves/restores iterator_state in the metadata JSON
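The persistence side of the changes listed above might look roughly like this (a sketch: the function names are hypothetical and the JSON key layout is an assumption based on the description, with iterator_state stored alongside the counters BackupAndRestore already tracks):

```python
import json
import os


def save_training_metadata(path, epoch, batch, iterator_state):
    # iterator_state is whatever the adapter returned, e.g.
    # {"next_index": 5}; None means "nothing to restore".
    metadata = {
        "epoch": epoch,
        "batch": batch,
        "iterator_state": iterator_state,
    }
    with open(path, "w") as f:
        json.dump(metadata, f)


def load_iterator_state(path):
    # Missing file (first run, or no backup yet) means no state.
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f).get("iterator_state")
```

Keeping the state as plain JSON means no new checkpoint format is needed: the existing training_metadata.json simply gains one more key.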
- TensorFlow backend: get_tf_dataset() wraps Grain inside tf.data.Dataset.from_generator(), hiding the iterator. Checkpoint/resume works on the JAX, numpy, and torch backends.
- PyTorch DataLoader doesn't expose state methods, so it returns None.
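The num_batches fix from the summary reduces to reporting len() when the dataset is finite; a minimal sketch (class simplified, attribute names assumed):

```python
class GrainAdapterSketch:
    """Illustrates the num_batches fix: return the real count for
    finite datasets instead of a hardcoded None."""

    def __init__(self, dataset):
        self._dataset = dataset

    @property
    def num_batches(self):
        # Finite MapDatasets support len(), so the progress bar can
        # show a total; datasets without a length stay unknown (None).
        try:
            return len(self._dataset)
        except TypeError:
            return None
```

Falling back to None preserves the old behavior for IterDatasets and other unsized inputs, so only the progress-bar display changes.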