diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 9a63de13890f..729a31c04b1b 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -7724,13 +7724,13 @@ def _reduce_tree(*xs, axis=0):
       n = xs[0].shape[axis]
       n1 = (n + 1) // 2
       n2 = n - n1
-      xs1 = [slicing.slice_in_dim(x, 0, n1) for x in xs]
-      xs2 = [slicing.slice_in_dim(x, n1, None) for x in xs]
+      xs1 = [slicing.slice_in_dim(x, stride=2) for x in xs]
+      xs2 = [slicing.slice_in_dim(x, 1, stride=2) for x in xs]
       if n2 != n1:
         paddings = [(0, 0, 0)] * len(xs[0].shape)
         paddings[axis] = (0, 1, 0)
         xs2 = [pad(x2, i, paddings) for x2, i in zip(xs2, init_values)]
-      xs = reducer(*(xs1 + xs2))
+      xs = reducer(*xs1, *xs2)
     if xs[0].shape[axis] == 0:
       return [full(input_shape[non_axes], i) for i in init_values]
     return tuple(squeeze(x, (axis,)) for x in xs)
diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py
index 9878e8ed9435..51777db24a99 100644
--- a/jax/_src/lax/slicing.py
+++ b/jax/_src/lax/slicing.py
@@ -939,8 +939,8 @@ def index_take(src: Array, idxs: Array, axes: Sequence[int]) -> Array:
 
 ### convenience wrappers around traceables
 
-def slice_in_dim(operand: Array | np.ndarray, start_index: int | None,
-                 limit_index: int | None,
+def slice_in_dim(operand: Array | np.ndarray, start_index: int | None = 0,
+                 limit_index: int | None = None,
                  stride: int = 1, axis: int = 0) -> Array:
   """Convenience wrapper around :func:`lax.slice` applying to only one dimension.
 
diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py
index a6398e402df9..3a9b9a4e7f2d 100644
--- a/tests/lax_autodiff_test.py
+++ b/tests/lax_autodiff_test.py
@@ -25,6 +25,7 @@
 import numpy as np
 
 import jax
+import jax.numpy as jnp
 from jax import dtypes
 from jax import lax
 from jax._src import test_util as jtu
@@ -765,6 +766,15 @@ def op(xs, ys):
     reduce = lambda xs, ys: lax.reduce((xs, ys), init_vals, op, dims)
     check_grads(reduce, operands, 2, ["fwd", "rev"], tol, tol)
 
+  def test_reduce_grad_doesnt_reorder(self):
+    # https://github.com/jax-ml/jax/issues/32474
+    def f(arr):
+      return jax.lax.reduce(arr, 1.0, lambda x, y: x * y ** jnp.sign(x), [0])
+    inp = jnp.array([1, -2, 3, -4], dtype=float)
+    ans, _ = jax.jvp(f, (inp,), (inp,))
+    expected = f(inp)
+    self.assertAllClose(ans, expected, check_dtypes=False)
+
   @jtu.sample_product(
     [dict(init_val=init_val, op=op, dtype=dtype, rng_factory=rng_factory,
           shape=shape, dims=dims, strides=strides, padding=padding,
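
A quick sketch of the behavior this patch enables; it is not part of the diff, and the sample values are made up for illustration. With start_index and limit_index now defaulting to 0 and None, slice_in_dim can take a strided slice over a whole axis without explicit bounds, and _reduce_tree's even/odd split pairs adjacent elements rather than pairing element i with element i + n1:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(6)                        # [0, 1, 2, 3, 4, 5]
evens = lax.slice_in_dim(x, stride=2)    # start/limit default to 0/None -> [0, 2, 4]
odds = lax.slice_in_dim(x, 1, stride=2)  # odd positions -> [1, 3, 5]
# A reducer applied to (evens, odds) now combines x[0] with x[1], x[2] with x[3], and so on.
# The old halving split (x[:3] vs x[3:]) combined x[0] with x[3], which reordered operands
# for order-sensitive reducers; the new regression test checks this via jax.jvp.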