
Add Grain iterator checkpoint/resume and fix num_batches#22490

Open
MarcosAsh wants to merge 1 commit into keras-team:master from MarcosAsh:grain-checkpoint-resume

Conversation

@MarcosAsh
Contributor

Summary

Right now, if training with a Grain dataset gets interrupted and resumed via BackupAndRestore, the data pipeline restarts from the beginning: model weights are restored but the iterator position is lost. This means the model sees early data twice and skips later data entirely.

This PR fixes that by wiring Grain's built-in DatasetIterator.get_state() / set_state() through the training stack, so BackupAndRestore can save and restore the exact iterator position. The state is tiny (just {"next_index": 5}) and gets written into the existing training_metadata.json.

Also fixes num_batches for finite MapDataset: it was hardcoded to None, which meant progress bars never showed a total. Now it returns the actual count via len().
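To make the mechanism concrete, here is a minimal, self-contained sketch of the get_state()/set_state() contract the PR relies on. ToyDatasetIterator is a hypothetical stand-in for Grain's DatasetIterator, not the real class:

```python
class ToyDatasetIterator:
    """Toy stand-in for Grain's DatasetIterator (hypothetical class,
    illustrating the same get_state/set_state contract described above)."""

    def __init__(self, data):
        self._data = list(data)
        self._next_index = 0

    def __next__(self):
        if self._next_index >= len(self._data):
            raise StopIteration
        item = self._data[self._next_index]
        self._next_index += 1
        return item

    def get_state(self):
        # Tiny, JSON-serializable state, e.g. {"next_index": 5}.
        return {"next_index": self._next_index}

    def set_state(self, state):
        self._next_index = state["next_index"]


it = ToyDatasetIterator(range(10))
for _ in range(5):
    next(it)
state = it.get_state()  # {"next_index": 5}

# Simulate an interruption: build a fresh iterator and restore its position.
it2 = ToyDatasetIterator(range(10))
it2.set_state(state)
assert next(it2) == 5  # resumes mid-epoch: no repeated or skipped data
```

Because the state is a small dict, it can be serialized straight into the same metadata JSON that BackupAndRestore already writes.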

Changes

  • DataAdapter base class gets optional get_iterator_state() / set_iterator_state() (no-op defaults, backward compatible)
  • GrainDatasetAdapter tracks the live iterator via a _TrackableIterable wrapper and implements the state methods
  • EpochIterator delegates state calls to the adapter
  • All three backend trainers expose the epoch_iterator on the model so callbacks can reach it
  • BackupAndRestore saves/restores iterator_state in the metadata JSON
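A rough sketch of the shape of these changes (method names taken from the PR; the class bodies are simplified stand-ins, not the actual Keras source):

```python
class DataAdapter:
    def get_iterator_state(self):
        # No-op default: adapters without resumable iterators return None,
        # keeping existing adapters backward compatible.
        return None

    def set_iterator_state(self, state):
        # No-op default: state is ignored by non-resumable adapters.
        pass


class GrainDatasetAdapter(DataAdapter):
    def __init__(self):
        # In the PR this is tracked via a _TrackableIterable wrapper;
        # here a plain attribute stands in for it.
        self._iterator = None

    def get_iterator_state(self):
        if self._iterator is not None and hasattr(self._iterator, "get_state"):
            return self._iterator.get_state()
        return None  # e.g. legacy DataLoader without state methods


class EpochIterator:
    def __init__(self, adapter):
        self._adapter = adapter

    def get_iterator_state(self):
        # Delegate to the adapter, whatever its concrete type.
        return self._adapter.get_iterator_state()
```

With this layering, BackupAndRestore only ever talks to the EpochIterator; it never needs to know which adapter (or backend) is in use.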

Limitations

  • TF backend: get_tf_dataset() wraps Grain inside tf.data.Dataset.from_generator(), hiding the iterator. Checkpoint/resume works on JAX, numpy, and torch backends.
  • Grain's legacy DataLoader doesn't expose state methods, so it returns None.
  • Progress bar step numbers restart from 0 on the resumed epoch; this is cosmetic and can be a follow-up.
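For illustration, a minimal sketch of folding the iterator state into a metadata JSON the way the PR describes. The save_metadata/load_metadata helpers and the "epoch" key are assumptions for this sketch; only the training_metadata.json file name and the iterator_state field come from the PR text:

```python
import json
import os
import tempfile


def save_metadata(backup_dir, iterator_state, epoch):
    """Write iterator state alongside other training metadata."""
    os.makedirs(backup_dir, exist_ok=True)
    metadata = {"epoch": epoch, "iterator_state": iterator_state}
    with open(os.path.join(backup_dir, "training_metadata.json"), "w") as f:
        json.dump(metadata, f)


def load_metadata(backup_dir):
    """Return the saved metadata dict, or None if no backup exists."""
    path = os.path.join(backup_dir, "training_metadata.json")
    if not os.path.exists(path):
        return None
    with open(path) as f:
        return json.load(f)


backup_dir = tempfile.mkdtemp()
save_metadata(backup_dir, {"next_index": 5}, epoch=3)
meta = load_metadata(backup_dir)
assert meta["iterator_state"] == {"next_index": 5}
```

On resume, the restored iterator_state would be handed back through set_iterator_state() before the first batch of the interrupted epoch is drawn.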

@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the fault tolerance of training pipelines by enabling deterministic checkpoint and resume for Grain datasets. Previously, interruptions would cause data pipelines to restart from the beginning, leading to data re-exposure and skipped data. The changes introduce mechanisms to save and restore the exact state of data iterators, ensuring seamless continuation of training. Additionally, the num_batches property for finite Grain MapDatasets is now correctly reported, improving progress tracking. The PyTorch LSTM implementation also received an update, addressing gradient issues and improving weight handling.

Highlights

  • Grain Iterator Checkpoint/Resume: Implemented deterministic checkpoint and resume functionality for Grain dataset iterators, allowing training to restart from the exact iterator position rather than from the beginning.
  • Data Adapter State Management: Introduced get_iterator_state() and set_iterator_state() methods in the base DataAdapter class and EpochIterator for managing and restoring data pipeline state.
  • GrainDatasetAdapter Enhancements: Updated GrainDatasetAdapter to track live iterators via a new _TrackableIterable wrapper, enabling state capture and restoration. Also, the num_batches property for finite MapDataset instances now correctly reports the actual count.
  • Trainer Integration: Exposed the _epoch_iterator attribute in JAX, TensorFlow, and PyTorch trainers, making the live iterator accessible to callbacks like BackupAndRestore for state persistence.
  • BackupAndRestore Callback Update: Integrated iterator state saving and restoring into the BackupAndRestore callback, ensuring that the data pipeline state is preserved alongside model weights.
  • PyTorch LSTM Backend Refinement: Refactored the PyTorch LSTM implementation to use torch._VF.lstm directly, improving weight handling and resolving previous gradient issues with CuDNN, leading to the removal of related xfail markers in tests.


Enable deterministic mid-epoch resume for Grain datasets by saving and
restoring the DatasetIterator state through BackupAndRestore. Also fix
num_batches to return the actual count for finite MapDatasets so progress
bars work correctly.
@MarcosAsh MarcosAsh force-pushed the grain-checkpoint-resume branch from 68ab6e9 to 0ff26a8 on March 25, 2026 00:56
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces data-pipeline state checkpointing and restoration capabilities, primarily for Grain datasets. It exposes the _epoch_iterator in JAX, TensorFlow, and PyTorch trainers, allowing the BackupAndRestore callback to save and restore the iterator's state. New get_iterator_state and set_iterator_state methods are added to the DataAdapter and EpochIterator classes, with the GrainDatasetAdapter implementing these for MapDataset and IterDataset via a new _TrackableIterable wrapper. Additionally, the PyTorch LSTM backend is refactored to use torch._VF.lstm directly, fixing previous cuDNN issues related to gradients and masking, which is reflected in the updated lstm_test.py. I have no feedback to provide on the review comments as none were provided.

@MarcosAsh
Contributor Author

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces robust fault tolerance for data pipelines, particularly for Grain datasets, by enabling the saving and restoring of data iterator states. The changes involve exposing the _epoch_iterator in the JAX, TensorFlow, and PyTorch trainers, and implementing get_iterator_state and set_iterator_state methods in the DataAdapter and EpochIterator classes. The GrainDatasetAdapter is updated to manage and persist the iterator state for grain.MapDataset and grain.IterDataset, including a new _TrackableIterable wrapper. The BackupAndRestore callback now utilizes this functionality to save and restore the data pipeline state. A review comment suggests adding a warning when an iterator state cannot be restored due to a missing set_state method, to provide better user feedback and adhere to API design guidelines regarding detailed error messages and anticipating common mistakes.

Comment on lines +31 to +34

```python
if self._adapter._pending_iterator_state is not None:
    if hasattr(it, "set_state"):
        it.set_state(self._adapter._pending_iterator_state)
    self._adapter._pending_iterator_state = None
```
Contributor


medium

If a pending iterator state exists but the new iterator does not have a set_state method, the state is silently discarded. This could lead to unexpected behavior where the data pipeline restarts from the beginning of the epoch without any warning to the user.

Consider adding a warning in the else case to inform the user that the iterator state could not be restored. This would improve debuggability for users who might inadvertently change their data pipeline in a way that makes it non-resumable.

Suggested change

```python
if self._adapter._pending_iterator_state is not None:
    if hasattr(it, "set_state"):
        it.set_state(self._adapter._pending_iterator_state)
    else:
        import warnings

        warnings.warn(
            "Attempted to restore a data iterator state, but the "
            "iterator does not support it. The data pipeline will "
            "restart from the beginning of the epoch.",
            stacklevel=2,
        )
    self._adapter._pending_iterator_state = None
```
References
  1. The Keras API design guidelines state to 'Provide detailed feedback messages upon user error' (line 140) and 'anticipate common mistakes' (line 139). Silently failing to restore an iterator's state is a potential source of confusion for users, and a warning would provide helpful feedback. (link)

@codecov-commenter

codecov-commenter commented Mar 25, 2026

Codecov Report

❌ Patch coverage is 77.41935% with 14 lines in your changes missing coverage. Please review.
✅ Project coverage is 83.16%. Comparing base (3b2e3f8) to head (0ff26a8).

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| keras/src/callbacks/backup_and_restore.py | 30.00% | 4 Missing and 3 partials ⚠️ |
| ...rc/trainers/data_adapters/grain_dataset_adapter.py | 89.18% | 2 Missing and 2 partials ⚠️ |
| keras/src/trainers/epoch_iterator.py | 60.00% | 2 Missing ⚠️ |
| keras/src/trainers/data_adapters/data_adapter.py | 75.00% | 1 Missing ⚠️ |
Additional details and impacted files
```
@@            Coverage Diff             @@
##           master   #22490      +/-   ##
==========================================
- Coverage   83.17%   83.16%   -0.01%
==========================================
  Files         596      596
  Lines       67610    67670      +60
  Branches    10531    10543      +12
==========================================
+ Hits        56234    56280      +46
- Misses       8655     8664       +9
- Partials     2721     2726       +5
```
| Flag | Coverage Δ |
| --- | --- |
| keras | 82.99% <77.41%> (-0.01%) ⬇️ |
| keras-jax | 59.96% <67.74%> (+<0.01%) ⬆️ |
| keras-numpy | 54.26% <54.83%> (-0.01%) ⬇️ |
| keras-openvino | 51.17% <54.83%> (+<0.01%) ⬆️ |
| keras-tensorflow | 61.25% <66.12%> (+<0.01%) ⬆️ |
| keras-torch | 60.11% <69.35%> (+0.01%) ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.

@keerthanakadiri keerthanakadiri self-assigned this Mar 25, 2026
