Description
The documentation for `jax.lax.reduce` currently states that the computation (call it `f`) must form a monoid with the `init_values`, i.e. `f` must be associative and the initial value must be an identity for `f`. The documentation does not currently require `f` to be commutative.

However, `jax.grad` does assume that `f` is commutative. This can be observed either by looking at the generated jaxpr or by testing with a carefully crafted associative-but-not-commutative function:
```python
import jax
import jax.numpy as jnp

def f(arr):
    return jax.lax.reduce(arr, 1.0, lambda x, y: x * y ** jnp.sign(x), [0])

inp = jnp.array([1, -2, 3, -4], dtype=float)
print(f(inp))
print(jax.value_and_grad(f)(inp))
```
yields:

```
0.16666667
(Array(1.5, dtype=float32), Array([ 1.5  , -0.75 , 0.5  , 0.375], dtype=float32))
```

where the result of `f(inp)` (0.16666667) is unequal to the value returned by `value_and_grad` (1.5).
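The primal value can be cross-checked against a plain left fold in Python (a sketch; `functools.reduce` applies `f` left-to-right with the identity as the initial accumulator, and the pure-Python `sign` below is a stand-in for `jnp.sign` that agrees with it on nonzero inputs):

```python
import functools
import math

def sign(x):
    # Stand-in for jnp.sign on nonzero x (copysign maps 0.0 to +1.0, jnp.sign to 0).
    return math.copysign(1.0, x)

def f(x, y):
    return x * y ** sign(x)

inp = [1.0, -2.0, 3.0, -4.0]
# Left fold with the identity 1.0 as initial accumulator:
# f(f(f(f(1, 1), -2), 3), -4) = 1/6.
result = functools.reduce(f, inp, 1.0)
print(result)  # 0.16666..., matching the primal value printed above
```

Since `f` is associative, any bracketing of the in-order sequence gives the same result, so 1/6 is the value the documented monoid semantics imply.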
Note that `f(x, y) = x * y ** sign(x)` does indeed have 1.0 as identity (writing `s(x)` for `sign(x)`): `f(x, 1) = x * 1^s(x) = x` and `f(1, y) = 1 * y^s(1) = y`. `f` is associative: for nonzero x and y we have `s(x * y^s(x)) = s(x) * s(y)`, so both `f(f(x, y), z)` and `f(x, f(y, z))` expand to `x * y^s(x) * z^(s(x) * s(y))`. And `f` is clearly not commutative, e.g. `f(2, -3) = -6` while `f(-3, 2) = -1.5`.

EDIT: the proofs above hold only if none of the inputs are zero. The argument as a whole still holds, though.
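The identities above can be spot-checked numerically (a sketch in pure Python; `sign` here uses `math.copysign`, which agrees with `jnp.sign` away from zero):

```python
import math
import random

def sign(x):
    # Agrees with jnp.sign for nonzero x.
    return math.copysign(1.0, x)

def f(x, y):
    return x * y ** sign(x)

random.seed(0)
for _ in range(1000):
    x, y, z = (random.uniform(-5, 5) for _ in range(3))
    # Identity: 1.0 is a two-sided identity for f.
    assert math.isclose(f(x, 1.0), x) and math.isclose(f(1.0, x), x)
    # Associativity (holds for nonzero inputs).
    assert math.isclose(f(f(x, y), z), f(x, f(y, z)))

# Non-commutativity: f(2, -3) = -6.0, but f(-3, 2) = -1.5.
print(f(2.0, -3.0), f(-3.0, 2.0))
```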
If I'm not mistaken, the documentation for `jax.lax.reduce` should note that commutativity is assumed.
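A central finite-difference check of the left fold (pure Python, a sketch) also shows that the gradient `[1.5, -0.75, 0.5, 0.375]` reported above does not match the actual sensitivity of the reduction:

```python
import functools
import math

def sign(x):
    # Agrees with jnp.sign for nonzero x.
    return math.copysign(1.0, x)

def f(x, y):
    return x * y ** sign(x)

def fold(arr):
    # The documented monoid reduction as a left fold with identity 1.0.
    return functools.reduce(f, arr, 1.0)

inp = [1.0, -2.0, 3.0, -4.0]
eps = 1e-6
fd_grad = []
for i in range(len(inp)):
    up = inp.copy(); up[i] += eps
    dn = inp.copy(); dn[i] -= eps
    fd_grad.append((fold(up) - fold(dn)) / (2 * eps))

print(fd_grad)  # ≈ [0.1667, -0.0833, -0.0556, 0.0417], far from [1.5, -0.75, 0.5, 0.375]
```

Locally the fold equals `a0 * a1 / (a2 * a3)` for these signs, whose gradient at `inp` is `[1/6, -1/12, -1/18, 1/24]`, matching the finite differences.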
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.3.3
python: 3.13.7 (main, Aug 15 2025, 12:34:02) [GCC 15.2.1 20250813]
device info: cpu-1, 1 local devices