
Conversation

mattjj (Collaborator) commented Oct 10, 2025

fixes #32474
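
For context, here is a hypothetical repro sketch (not taken from the issue or this PR): composing affine maps x -> m*x + b is associative but not commutative, so a reduction gradient rule that silently reorders operands can produce wrong gradients.

```python
import jax
import jax.numpy as jnp
from jax import lax

def compose(f, g):
    # Composition of affine maps x -> m*x + b, represented as (m, b):
    # (m1, b1) o (m2, b2) = (m1*m2, m1*b2 + b1). Associative, not commutative.
    (m1, b1), (m2, b2) = f, g
    return (m1 * m2, m1 * b2 + b1)

def offset_via_reduce(ms, bs):
    # Variadic lax.reduce over both components, starting from the identity map.
    _, b = lax.reduce((ms, bs), (1.0, 0.0), compose, (0,))
    return b

def offset_via_loop(ms, bs):
    # Order-preserving sequential reference.
    m, b = 1.0, 0.0
    for i in range(ms.shape[0]):
        m, b = compose((m, b), (ms[i], bs[i]))
    return b

ms = jnp.array([2.0, 3.0, 0.5, -1.0])
bs = jnp.array([1.0, -1.0, 4.0, 2.0])
# With an order-preserving gradient rule, these should agree.
print(jax.grad(offset_via_reduce)(ms, bs))
print(jax.grad(offset_via_loop)(ms, bs))
```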

mattjj (Collaborator, Author) commented Oct 10, 2025

Passes on CPU and TPU but not GPU? :think-cry:

tomsmeding commented

While I posted the issue that led to this PR (thanks!), I don't really use JAX in practice, so I'm not sure I'm the right person to review or test any of this. :)

Just one thought: it's possible that the new strided version is slower than the original. The original used only non-strided slices, which likely compile down to a memcpy-like operation at some point; with striding, the memory gathers cannot be vectorised nearly as well, so the new version may well be slower.
On a high level, this is not surprising: it is known in the (functional) parallel-computing world that a parallel reduction can be faster if it is allowed to assume commutativity. This is, for example, also why Futhark has a separate reduction primitive for commutative operators (reduce_comm).
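
To make the trade-off concrete, here is a minimal eager-mode sketch of the two pairing orders (hypothetical, not this PR's actual code; assumes a power-of-two length and an elementwise-vectorised op):

```python
import jax.numpy as jnp

def reduce_halves(x, op):
    # Combine x[i] with x[i + n//2] each round: only contiguous slices
    # (cheap, memcpy-like), but operands are reordered, so this is only
    # valid when op is commutative as well as associative.
    while x.shape[0] > 1:
        half = x.shape[0] // 2
        x = op(x[:half], x[half:])
    return x[0]

def reduce_pairs(x, op):
    # Combine adjacent pairs x[2i], x[2i+1] each round: operand order is
    # preserved, so associativity suffices, but the strided slices gather
    # every other element and vectorise less well.
    while x.shape[0] > 1:
        x = op(x[0::2], x[1::2])
    return x[0]
```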

mattjj (Collaborator, Author) commented Oct 10, 2025

Great points, @tomsmeding! Yes, I think if we knew of some users (or developers) who were interested in this, we would add more structure here, like separate commutative and noncommutative reduces.
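
A hypothetical sketch of what such a split could look like (this flag does not exist in JAX; it just dispatches between the two pairing orders from the earlier sketch):

```python
import jax.numpy as jnp

def tree_reduce_1d(x, op, *, commutative=False):
    # Hypothetical API: callers declare commutativity to unlock the cheaper
    # contiguous-slice pairing. Assumes a power-of-two length.
    while x.shape[0] > 1:
        half = x.shape[0] // 2
        if commutative:
            x = op(x[:half], x[half:])  # reorders operands: needs commutativity
        else:
            x = op(x[0::2], x[1::2])    # order-preserving: associativity suffices
    return x[0]

# e.g. tree_reduce_1d(jnp.arange(8.0), jnp.add, commutative=True)
```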


Development

Successfully merging this pull request may close issue #32474: Gradient of jax.lax.reduce assumes computation is commutative.
