
Gradient of jax.lax.reduce assumes computation is commutative #32474

@tomsmeding

Description

The documentation for jax.lax.reduce currently states that the computation (let's call it f) must form a monoid with the init_values, i.e. f must be associative and the initial value must be an identity for f. The documentation does not currently require f to be commutative.

However, jax.grad does assume that f is commutative. This can be observed either by looking at the generated jaxpr or by testing with a carefully-crafted associative-but-not-commutative function:

import jax
import jax.numpy as jnp

def f(arr):
    return jax.lax.reduce(arr, 1.0, lambda x, y: x * y ** jnp.sign(x), [0])

inp = jnp.array([1, -2, 3, -4], dtype=float)
print(f(inp))
print(jax.value_and_grad(f)(inp))

yields:

0.16666667
(Array(1.5, dtype=float32), Array([ 1.5  , -0.75 ,  0.5  ,  0.375], dtype=float32))

where the result of f (0.16666667) does not equal the primal value returned by value_and_grad (1.5).
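
The reordering is also visible in the jaxpr of the gradient; one way to print it (jax.make_jaxpr is a standard JAX utility; this snippet relies on f and inp as defined above):

import jax
print(jax.make_jaxpr(jax.grad(f))(inp))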

Note that f does indeed have 1.0 as an identity (writing $\sigma(x)$ for sign(x)): $1 \cdot x^{\sigma(1)} = x^1 = x$ and $x \cdot 1^{\sigma(x)} = x \cdot 1 = x$. Furthermore, f is associative, since $\sigma(x \cdot y^{\sigma(x)}) = \sigma(x)\,\sigma(y)$ for nonzero $x$ and $y$:
$f(f(x,y),z) = (x \cdot y^{\sigma(x)}) \cdot z^{\sigma(x \cdot y^{\sigma(x)})} = x \cdot y^{\sigma(x)} \cdot z^{\sigma(x) \sigma(y)} = x \cdot (y \cdot z^{\sigma(y)})^{\sigma(x)} = f(x,f(y,z)),$
and f is clearly not commutative (e.g. $f(x,y) = xy$ while $f(y,x) = y/x$ when $x > 0 > y$).

EDIT: the proofs above hold only if none of the inputs are zero. The argument as a whole still holds, though.
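
For nonzero sample values, associativity and the lack of commutativity can also be spot-checked numerically; a minimal sketch:

import jax.numpy as jnp

op = lambda x, y: x * y ** jnp.sign(x)
x, y, z = jnp.array(2.0), jnp.array(-3.0), jnp.array(5.0)
print(op(op(x, y), z), op(x, op(y, z)))  # equal: associative on these values
print(op(x, y), op(y, x))                # -6.0 vs -1.5: not commutative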

If I'm not mistaken, the documentation for jax.lax.reduce should note that commutativity is assumed.
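
For comparison, a fold written with jax.lax.scan fixes the left-to-right order explicitly, so its primal value matches f(inp) above and its gradient is differentiated through that same fixed order. A minimal sketch (f_seq is a hypothetical helper, not part of jax.lax):

import jax
import jax.numpy as jnp

def f_seq(arr):
    # Hypothetical sequential left fold with the same computation and identity element.
    step = lambda acc, y: (acc * y ** jnp.sign(acc), None)
    result, _ = jax.lax.scan(step, jnp.array(1.0), arr)
    return result

inp = jnp.array([1, -2, 3, -4], dtype=float)
print(jax.value_and_grad(f_seq)(inp))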

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.3.3
python: 3.13.7 (main, Aug 15 2025, 12:34:02) [GCC 15.2.1 20250813]
device info: cpu-1, 1 local devices
