-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Fix Ref indexing discrepancies with array indexing #33456
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix Ref indexing discrepancies with array indexing #33456
Conversation
Summary of ChangesHello @yashwantbezawada, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request aims to improve the consistency and robustness of Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces several fixes to make Ref indexing more consistent with JAX array indexing. It correctly handles out-of-bounds empty slices, adds support for None to insert new axes, and improves the error message for negative slice steps. The changes are well-implemented. I have one suggestion to refactor a piece of duplicated logic to improve code maintainability.
0fc6707 to
e11f661
Compare
|
Updated this PR to fully fix all the issues from #33322. Added RefFlip transform to handle negative slice steps - these get converted to positive equivalents at the RefIndexer level and then flipped to get the right element order. Also updated discharge.py to handle both RefNewAxis and RefFlip in the transform functions. Refactored the duplicated code in RefIndexer.getitem as suggested in the review. All four test cases from the issue should work now (OOB clamping, None indexing, ellipsis+None, and negative slices). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request is a great step towards aligning Ref indexing with standard JAX array indexing. The changes to allow empty slices, support None for new axis insertion, and improve the error message for negative steps are all valuable improvements. The implementation using RefNewAxis and RefFlip transforms is a clean and effective approach. My review includes a few suggestions to refactor some loops into single function calls for better performance and to break down a particularly complex method for improved readability. Overall, this is a solid contribution.
This addresses the issues in jax-ml#33322 where Ref indexing behaved differently from JAX array indexing: 1. OOB slice clamping: Allow empty slices when start equals dim instead of raising an error 2. None indexing: Add RefNewAxis transform to handle np.newaxis in indices, enabling x[None] and x[..., None] 3. Negative slice steps: Convert negative step slices to positive equivalents and apply RefFlip transform to reverse the result 4. Updated discharge.py to handle the new transforms in both transform_array and transform_swap_array 5. Refactored RefIndexer.__getitem__ to reduce code duplication Fixes jax-ml#33322.
e11f661 to
5094008
Compare
This addresses the issues raised in #33322 where Ref indexing behaves differently from regular JAX array indexing.
Changes:
slice(11, 12, 1)on size-10 array now returns empty result instead of erroring)Noneindexing to insert new axes (e.g.,x[None],x[..., None])x[::-1]) by converting them to positive steps and applying a flip transformRefIndexer.__getitem__into helper functions for better maintainabilityFor
Noneindexing, added aRefNewAxistransform that tracks where new axes should be inserted. For negative steps, added aRefFliptransform that reverses the specified axes after indexing with the positive-step equivalent.Fixes #33322.