Skip to content

Conversation

@nandan2003
Copy link

Description

Fixes #29744.

jax.block_until_ready(x) previously performed a no-op silently if x contained no JAX arrays (e.g., passing a Python integer or list of integers). This led users to believe synchronization had occurred when it had not.

This PR adds a check to track if block_until_ready() was successfully called on any leaf of the input. If the input traversal finishes without blocking on anything, a UserWarning is issued.

This implements the "opt-in/deprecation cycle" approach suggested in the issue, rather than immediately raising an error which could be a breaking change.

Reproduction

import jax
# Before: Prints nothing, returns immediately
# After: Prints UserWarning
jax.block_until_ready(100)

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @nandan2003, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the robustness of jax.block_until_ready() by providing explicit feedback to users when the function's intended synchronization behavior is not met. It introduces a warning mechanism to alert users about calls that result in a no-op, thereby improving clarity and preventing potential misinterpretations of program execution flow.

Highlights

  • Improved block_until_ready behavior: The jax.block_until_ready() function now issues a UserWarning if it is called on an input that contains no JAX arrays or objects with a block_until_ready() method, and thus performs no actual blocking operation.
  • Addresses silent no-op: This change prevents users from mistakenly believing synchronization has occurred when block_until_ready() silently did nothing on non-JAX objects, such as Python integers or lists.
  • Opt-in/deprecation cycle: The implementation follows an opt-in/deprecation cycle by issuing a warning rather than immediately raising an error, avoiding a breaking change for existing code.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request adds a useful warning when jax.block_until_ready is called on an input that contains no JAX arrays or other blockable objects, preventing a silent no-op that could be misleading. The implementation is correct. I've added one suggestion to refactor the logic for improved readability and to adhere to Python's import conventions.

@Gattocrucco
Copy link
Contributor

Gattocrucco commented Nov 24, 2025

Do you think this should be an AND or an OR? I mean: you currently raise the warning if no leaf was blocked; should you instead raise the warning if any leaf was not blocked? My intuition is that the best behavior is the second but I don't have a strong opinion.

@nandan2003
Copy link
Author

Thanks, GattoCRUCCo. I agree that the second behavior is a much better guardrail. If any non-blockable object is present in a PyTree that also contains arrays, it indicates a structural error in the user's code that should be surfaced.

I will update the implementation to fire the warning if any leaf in the input PyTree was a non-JAX object that was silently ignored.

@jakevdp
Copy link
Collaborator

jakevdp commented Nov 24, 2025

I'm not sure this change is desirable – a common pattern is to call jax.block_until_ready(tree) so that any arrays in the tree will block, and any non-arrays will pass through unchanged. What would be the alternative to that pattern if we merge this PR?

@nandan2003 nandan2003 force-pushed the fix-block-until-ready branch from 20088c1 to 1b8acb9 Compare November 24, 2025 15:44
@nandan2003
Copy link
Author

The logic has been finalized to implement the permissive guardrail. The PR now only issues a UserWarning if the input PyTree contains zero blockable objects (total no-op). This avoids false warnings for valid mixed use cases ([array, int]) while still fixing the original silent failure bug.

@jakevdp
Copy link
Collaborator

jakevdp commented Nov 24, 2025

I still don't think we'd want this change. I think it would be confusing if block_until_ready(tree) produces a warning if you happen to pass a tree that doesn't contain a JAX object.

Put yourself in the user's shoes: what would be the recourse when you see this warning? You'd have to defensively write something like this every time you were passing a tree to block_until_ready:

if not tree_all(lambda obj: not isinstance(obj, jax.Array), tree):
  tree = jax.block_until_ready(tree)

...otherwise you risk hitting a warning if your tree changes its contents.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

jax.block_until_ready() should error on non-arrays by default

3 participants