Update orbax guide #2308

Merged: hertschuh merged 6 commits into keras-team:master from amitsrivastava78:update-orbax-guide on Mar 4, 2026

@amitsrivastava78 (Contributor) commented Mar 4, 2026

Replaces the existing Orbax checkpointing guide (which required users to define and copy-paste custom KerasOrbaxCheckpointManager and OrbaxCheckpointCallback wrapper classes) with a comprehensive guide for the built-in keras.callbacks.OrbaxCheckpoint callback.

What changed

  • Removed ~100 lines of manual wrapper boilerplate — the built-in callback handles everything.
  • All examples use the public keras.callbacks.OrbaxCheckpoint API directly.

Sections covered

  1. Basic Usage: drop-in callback with model.fit()
  2. Loading a model: keras.saving.load_model()
  3. Loading weights only: model.load_weights()
  4. Resuming training: step recovery from optimizer iterations
  5. Save best only: monitor, mode, save_best_only
  6. Batch-level checkpointing: save_freq=N
  7. Distributed training: Keras Distribution API with ModelParallel
  8. Cross-layout resharding: load under a different LayoutMap
  9. Callback parameters reference: full parameter table
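
To make the "Save best only" item concrete, here is a minimal pure-Python sketch of the decision such a policy makes from a monitored metric and a mode. It is an illustration of the general semantics only, not the implementation of keras.callbacks.OrbaxCheckpoint; the function names `is_improvement` and `epochs_to_save` and the sample metric values are hypothetical.

```python
import math


def is_improvement(current, best, mode):
    """Return True when `current` beats `best` under the given mode."""
    if mode == "min":
        return current < best
    if mode == "max":
        return current > best
    raise ValueError(f"unsupported mode: {mode!r}")


def epochs_to_save(history, mode="min"):
    """Return the epoch indices at which a best-only policy would save.

    `history` holds one monitored value per epoch (for example val_loss).
    This is a hypothetical helper, not part of Keras or Orbax.
    """
    # Start from the worst possible value so the first epoch always saves.
    best = math.inf if mode == "min" else -math.inf
    saved = []
    for epoch, value in enumerate(history):
        if is_improvement(value, best, mode):
            best = value
            saved.append(epoch)
    return saved


# Made-up validation losses, one per epoch.
val_loss = [0.90, 0.70, 0.75, 0.60, 0.60]
print(epochs_to_save(val_loss, mode="min"))  # [0, 1, 3]
```

A tie (the last epoch above) does not count as an improvement in this sketch; whether the real callback treats ties the same way is not stated in this PR.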

@gemini-code-assist (Contributor)

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request overhauls the Orbax checkpointing guide for Keras, transitioning from a custom callback approach to utilizing the newly integrated keras.callbacks.OrbaxCheckpoint. The update provides a more streamlined and feature-rich experience for users, offering extensive examples that demonstrate its capabilities across various scenarios, including complex distributed training setups and flexible cross-layout resharding. The aim is to provide clear, practical guidance for saving and restoring model states efficiently within the JAX backend.

Highlights

  • Updated Orbax Checkpointing Guide: The guide for Orbax checkpointing in Keras has been significantly updated to reflect the usage of the built-in keras.callbacks.OrbaxCheckpoint.
  • Removal of Custom Implementations: The previously required custom KerasOrbaxCheckpointManager and OrbaxCheckpointCallback classes have been removed, simplifying the integration of Orbax checkpointing.
  • Expanded Examples and Features: New sections and examples have been added covering basic usage, loading models and weights, resuming training, save_best_only functionality, batch-level checkpointing, distributed training with model parallelism, and cross-layout resharding.
  • Distributed Training and Resharding: The guide now includes detailed demonstrations of how OrbaxCheckpoint works seamlessly with the Keras Distribution API, supporting multi-host environments and automatic resharding of checkpoints when loading under different layout maps.
  • Callback Parameters Reference: A comprehensive table outlining all available parameters for the keras.callbacks.OrbaxCheckpoint has been included for easy reference.

Changelog
  • guides/orbax_checkpoint.py
    • Replaced custom KerasOrbaxCheckpointManager and OrbaxCheckpointCallback with the built-in keras.callbacks.OrbaxCheckpoint.
    • Expanded the guide with new sections on basic usage, loading models, loading weights, resuming training, save_best_only, and batch-level checkpointing.
    • Added comprehensive examples for distributed training with model parallelism, including DeviceMesh and LayoutMap configurations.
    • Demonstrated cross-layout resharding capabilities for loading checkpoints under different sharding layouts.
    • Included a reference table for OrbaxCheckpoint callback parameters.
    • Updated guide metadata, including author, creation, and last modified dates, and a more descriptive summary.
    • Modified the example model architecture and training data size for clarity and robustness.
    • Added JAX-specific environment variable settings for simulating multi-device environments.
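
As a sketch of what the last changelog item might look like in practice: JAX can be asked to expose several virtual CPU devices through the XLA flag `--xla_force_host_platform_device_count`, and Keras selects its backend from the `KERAS_BACKEND` environment variable. The exact variables and the device count used in the guide are not shown in this PR, so the count of 8 and the helper name `simulate_devices` below are assumptions.

```python
import os


def simulate_devices(count):
    """Request `count` virtual CPU devices from XLA (hypothetical helper).

    The flag only takes effect if it is set before JAX initializes its
    backend, so call this before importing jax or keras.
    """
    flag = f"--xla_force_host_platform_device_count={count}"
    existing = os.environ.get("XLA_FLAGS", "")
    # Append rather than overwrite, so flags set elsewhere are preserved.
    os.environ["XLA_FLAGS"] = f"{existing} {flag}".strip()


# Assumption: the guide runs on the JAX backend with 8 simulated devices.
os.environ["KERAS_BACKEND"] = "jax"
simulate_devices(8)
```

After this runs, importing jax in the same process and calling `jax.devices()` should report the simulated devices on a CPU-only machine.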

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request provides a comprehensive update to the Orbax checkpointing guide. The previous version, which relied on custom-built callback classes, has been replaced with a new guide that leverages the official, built-in keras.callbacks.OrbaxCheckpoint. The updated guide is much more thorough, covering basic usage, resuming training, distributed training with the Keras Distribution API, and cross-layout resharding. The changes significantly improve the quality and usefulness of the guide. I have one minor suggestion to reduce code duplication for improved clarity.

@amitsrivastava78 (Contributor Author)

@hertschuh The guide is ready for review. PTAL

@hertschuh (Contributor) left a comment

Approved with a couple nitpicks:

@amitsrivastava78 (Contributor Author)

> Approved with a couple nitpicks:

Thanks for approving. I have updated the guide as per the comments.

@hertschuh merged commit 73e0ac8 into keras-team:master on Mar 4, 2026
3 checks passed