Skip to content

Conversation

@dougalm
Copy link
Collaborator

@dougalm dougalm commented Nov 19, 2025

No description provided.

@dougalm dougalm added the pull ready Ready for copybara import and testing label Nov 19, 2025
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @dougalm, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request fundamentally refactors how JaxprInputEffect and its derived effects are handled within JAX. By removing the input_index from these effects, the system moves away from positional effect tracking, simplifying JAXPR transformations, control flow operations, and state management. This change aims to reduce complexity and potential errors associated with re-indexing effects when variables are reordered or transformed, leading to a more robust and maintainable effect system.

Highlights

  • Removal of Positional Effects: The input_index attribute has been removed from JaxprInputEffect and its subclasses (like ReadEffect, WriteEffect, AccumEffect). This means these effects are no longer tied to a specific input position.
  • Simplified Effect Handling: Numerous functions across JAX's core, partial evaluation, control flow, and state management modules have been updated to reflect the non-positional nature of JaxprInputEffect. This includes removing _renumber_effects and related logic that previously adjusted effect indices during JAXPR transformations.
  • Streamlined Control Flow Logic: Code related to handling effects in control flow primitives (like cond and while loops) has been simplified, removing complex index adjustments and checks for conflicting writes that are no longer necessary with non-positional effects.
  • Updated State Primitives: The abstract evaluation for state primitives (get, swap, addupdate) now creates effects without an input_index, aligning with the new effect definition.
  • Test Updates: Existing tests for state effects have been updated to assert the presence of non-positional effects, confirming the change in behavior.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request experiments with removing positional information from effects by removing input_index from JaxprInputEffect. This is a significant and wide-ranging change that simplifies code across many modules, removing the need for complex effect re-numbering logic in various transformations like vmap, jit, and control flow primitives. The changes look consistent with the goal of making effects non-positional.

However, I found a critical issue in jax/_src/state/discharge.py where the logic for propagating effects from a run_state call seems to have been completely removed, causing it to incorrectly report no effects. This could lead to incorrect effect analysis for functions that use run_state.

pass
inner_to_outer_aval_mapping[i] = outer_ref_index
outer_ref_index += 1
nonlocal_effects = set()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The logic for propagating effects from the run_state jaxpr seems to have been completely removed. This results in nonlocal_effects always being an empty set, which is incorrect as run_state should propagate effects from its body, especially for non-discharged refs.

With the removal of positional effects, we can no longer map effects to specific input refs. A reasonable approach would be to over-approximate: if there are any non-discharged Ref inputs, all RefEffects from the inner jaxpr should be propagated. Non-RefEffects should always be propagated.

I suggest restoring a simplified version of the effect propagation logic. The is_ref variable on the next line will become unused and can be removed separately.

Suggested change
nonlocal_effects = set()
nonlocal_effects = jaxpr.effects if any(isinstance(aval, AbstractRef) for aval in avals) else {e for e in jaxpr.effects if not isinstance(e, RefEffect)}

@dougalm dougalm force-pushed the remove-positional-effects branch from 60215a7 to 1801635 Compare November 20, 2025 15:11
@dougalm dougalm force-pushed the remove-positional-effects branch from 1801635 to 161b183 Compare November 20, 2025 15:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant