-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Warn when block_until_ready is called on non-JAX objects #29744 #33498
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Summary of ChangesHello @nandan2003, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the robustness of Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request adds a useful warning when jax.block_until_ready is called on an input that contains no JAX arrays or other blockable objects, preventing a silent no-op that could be misleading. The implementation is correct. I've added one suggestion to refactor the logic for improved readability and to adhere to Python's import conventions.
|
Do you think this should be an AND or an OR? I mean: you currently raise the warning if no leaf was blocked; should you instead raise the warning if any leaf was not blocked? My intuition is that the best behavior is the second but I don't have a strong opinion. |
|
Thanks, GattoCRUCCo. I agree that the second behavior is a much better guardrail. If any non-blockable object is present in a PyTree that also contains arrays, it indicates a structural error in the user's code that should be surfaced. I will update the implementation to fire the warning if any leaf in the input PyTree was a non-JAX object that was silently ignored. |
|
I'm not sure this change is desirable – a common pattern is to call |
20088c1 to
1b8acb9
Compare
|
The logic has been finalized to implement the permissive guardrail. The PR now only issues a |
|
I still don't think we'd want this change. I think it would be confusing if Put yourself in the user's shoes: what would be the recourse when you see this warning? You'd have to defensively write something like this every time you were passing a tree to ...otherwise you risk hitting a warning if your tree changes its contents. |
Description
Fixes #29744.
jax.block_until_ready(x)previously performed a no-op silently ifxcontained no JAX arrays (e.g., passing a Python integer or list of integers). This led users to believe synchronization had occurred when it had not.This PR adds a check to track if
block_until_ready()was successfully called on any leaf of the input. If the input traversal finishes without blocking on anything, aUserWarningis issued.This implements the "opt-in/deprecation cycle" approach suggested in the issue, rather than immediately raising an error which could be a breaking change.
Reproduction