Skip to content

Commit 0922feb

Browse files
committed
Use a broadcasted gather in the sort JVP, rather than forming explicit iotas.
Use an unsigned index and promise that it is in bounds.
1 parent ef06607 commit 0922feb

File tree

2 files changed

+21
-10
lines changed

2 files changed

+21
-10
lines changed

jax/_src/lax/lax.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5506,15 +5506,26 @@ def _operands_to_keys(*operands, num_keys=1):
55065506

55075507
def _sort_jvp(primals, tangents, *, dimension, is_stable, num_keys):
55085508
shape = primals[0].shape
5509-
iotas = []
5510-
for dim, size in enumerate(shape):
5511-
iotas.append(broadcasted_iota(np.int64, shape, dim))
55125509
sorted_primals_and_idx = sort_p.bind(
5513-
*primals, iotas[dimension], dimension=dimension,
5514-
is_stable=is_stable, num_keys=num_keys)
5515-
idx = tuple(sorted_primals_and_idx[-1] if i == dimension else iotas[i]
5516-
for i in range(len(shape)))
5517-
tangents_out = tuple(t if type(t) is ad_util.Zero else t[idx] for t in tangents)
5510+
*primals, broadcasted_iota(np.uint64, shape, dimension),
5511+
dimension=dimension, is_stable=is_stable, num_keys=num_keys)
5512+
batch_dims = tuple(np.delete(np.arange(len(shape), dtype=np.int64),
5513+
dimension))
5514+
dnums = slicing.GatherDimensionNumbers(
5515+
offset_dims=(),
5516+
collapsed_slice_dims=(dimension,),
5517+
start_index_map=(dimension,),
5518+
operand_batching_dims=batch_dims,
5519+
start_indices_batching_dims=batch_dims,
5520+
)
5521+
idx = expand_dims(sorted_primals_and_idx[-1], (len(shape),))
5522+
gather_idx = partial(
5523+
slicing.gather,
5524+
start_indices=idx, dimension_numbers=dnums, slice_sizes=(1,) * len(shape),
5525+
mode=slicing.GatherScatterMode.PROMISE_IN_BOUNDS
5526+
)
5527+
tangents_out = [t if type(t) is ad_util.Zero else gather_idx(t)
5528+
for t in tangents]
55185529
return tuple(sorted_primals_and_idx[:-1]), tangents_out
55195530

55205531
def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys):

tests/lax_autodiff_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
834834
# TODO(b/205052657): enable more tests when supported
835835
@jtu.sample_product(
836836
[dict(shape=shape, axis=axis)
837-
for shape in [(5,), (5, 7)]
837+
for shape in [(5,), (5, 7), (4, 9, 3)]
838838
for axis in [len(shape) - 1]
839839
],
840840
dtype=[np.float32],
@@ -849,7 +849,7 @@ def testSortGrad(self, shape, dtype, axis, is_stable):
849849
# TODO(b/205052657): enable more tests when supported
850850
@jtu.sample_product(
851851
[dict(shape=shape, axis=axis)
852-
for shape in [(3,), (5, 3)]
852+
for shape in [(3,), (5, 3), (4, 9, 3)]
853853
for axis in [len(shape) - 1]
854854
],
855855
key_dtype=[np.float32],

0 commit comments

Comments
 (0)