
Commit 21f8885

yashk2810 authored and Google-ML-Automation committed
[sharding_in_types] Make argmax and argmin work with sharding_in_types. This also requires adding a reduce_p sharding rule.
PiperOrigin-RevId: 699244204
1 parent 7635605 commit 21f8885
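
What this buys in practice: with the experimental sharding-in-types mode enabled, jnp.argmax / jnp.argmin propagate an output sharding whose PartitionSpec simply drops the reduced axis, instead of failing for lack of a sharding rule. A minimal sketch of the intended behavior follows; the 2x2 mesh, the axis names, the assumption of four available devices, and the config-flag spelling are illustrative assumptions, not taken from this commit.

# Minimal sketch, not from the commit; mesh layout, device count, and the
# 'jax_sharding_in_types' flag spelling are assumptions.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import NamedSharding, PartitionSpec as P

jax.config.update('jax_sharding_in_types', True)  # assumed experimental flag

mesh = jax.make_mesh((2, 2), ('x', 'y'))          # assumes >= 4 devices
x = jax.device_put(np.arange(16.).reshape(8, 2),
                   NamedSharding(mesh, P('x', 'y')))

@jax.jit
def f(x):
  # Reducing over axis 0 drops 'x' from the spec, leaving P('y').
  return jnp.argmax(x, axis=0)

print(f(x).sharding)  # expected spec: P('y')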

File tree: 3 files changed, +59 −15 lines


jax/_src/lax/lax.py

Lines changed: 29 additions & 12 deletions
@@ -5007,6 +5007,11 @@ def _reduce_shape_rule(*avals, computation, jaxpr, dimensions):
     raise ValueError(f'reduce found non-scalar initial value: {init_val_shapes}')
   return [tuple(np.delete(op.shape, dimensions)) for op in operand_avals]
 
+def _reduce_sharding_rule(*avals, computation, jaxpr, dimensions):
+  operand_avals, _ = split_list(avals, [len(avals) // 2])
+  return [op.sharding.with_spec(tuple_delete(op.sharding.spec, dimensions))
+          for op in operand_avals]
+
 def _reduce_dtype_rule(*avals, computation, jaxpr, dimensions):
   operand_avals, init_val_avals = split_list(avals, [len(avals) // 2])
   operand_dtypes = [dtypes.canonicalize_dtype(op.dtype) for op in operand_avals]
@@ -5093,7 +5098,7 @@ def _reduce_jvp_rule(primals, tangents, *, computation, jaxpr, dimensions):
 reduce_p.def_impl(partial(dispatch.apply_primitive, reduce_p))
 reduce_p.def_abstract_eval(
     partial(standard_multi_result_abstract_eval, reduce_p, _reduce_shape_rule,
-            _reduce_dtype_rule, _reduce_weak_type_rule))
+            _reduce_dtype_rule, _reduce_weak_type_rule, _reduce_sharding_rule))
 batching.primitive_batchers[reduce_p] = _reduce_batch_rule
 ad.primitive_jvps[reduce_p] = _reduce_jvp_rule
 
@@ -5115,6 +5120,9 @@ def _reduce_lower(ctx, *values, computation, jaxpr, dimensions):
                                        *reducer.arguments,
                                        dim_var_values=ctx.dim_var_values)
     hlo.return_(mlir.flatten_ir_values(out_nodes))
+  if config.sharding_in_types.value:
+    return [mlir.lower_sharding_under_shit(ctx, r, aval)
+            for r, aval in safe_zip(op.results, ctx.avals_out)]
   return op.results
 
 mlir.register_lowering(reduce_p, _reduce_lower)
@@ -5227,7 +5235,12 @@ def _argminmax_shape_rule(operand, *, axes, index_dtype):
   if operand.shape[axis] < 1:
     raise ValueError("argmin and argmax require non-empty reduced dimension. "
                      f"operand.shape={operand.shape} {axis=}")
-  return tuple(np.delete(operand.shape, axis))
+  return util.tuple_delete(operand.shape, axis)
+
+def _argminmax_sharding_rule(operand, *, axes, index_dtype):
+  axis, = axes
+  return operand.sharding.with_spec(
+      util.tuple_delete(operand.sharding.spec, axis))
 
 def _argminmax_dtype_rule(operand, *, axes, index_dtype):
   if not dtypes.issubdtype(index_dtype, np.integer):
@@ -5264,30 +5277,34 @@ def _compute_argminmax(value_comparator, get_identity,
   # value_comparator is either lax.lt (for argmin) or lax.gt
   # get_identity(operand.dtype) is inf for argmin or -inf for argmax
   axis, = axes
-  indices = broadcasted_iota(index_dtype, np.shape(operand), axis)
+  indices = broadcasted_iota(
+      index_dtype, np.shape(operand), axis,
+      _sharding=operand.sharding if config.sharding_in_types.value else None)
   res = reduce([operand, indices],
                [get_identity(operand.dtype), np.array(0, index_dtype)],
                _ArgMinMaxReducer(value_comparator),
                axes)
   return res[1]
 
 argmin_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
-                              'argmin', weak_type_rule=_strip_weak_type)
+                              'argmin', weak_type_rule=_strip_weak_type,
+                              sharding_rule=_argminmax_sharding_rule)
 batching.defreducer(argmin_p, _get_min_identity)
 ad.defjvp_zero(argmin_p)
 
 argmax_p = standard_primitive(_argminmax_shape_rule, _argminmax_dtype_rule,
-                              'argmax', weak_type_rule=_strip_weak_type)
+                              'argmax', weak_type_rule=_strip_weak_type,
+                              sharding_rule=_argminmax_sharding_rule)
 batching.defreducer(argmax_p, _get_max_identity)
 ad.defjvp_zero(argmax_p)
 
-mlir.register_lowering(argmin_p, mlir.cache_lowering(mlir.lower_fun(
-    partial(_compute_argminmax, lt, _get_min_identity),
-    multiple_results=False)))
+mlir.register_lowering(argmin_p, mlir.cache_lowering(
+    mlir.lower_fun(partial(_compute_argminmax, lt, _get_min_identity),
+                   multiple_results=False)))
 
-mlir.register_lowering(argmax_p, mlir.cache_lowering(mlir.lower_fun(
-    partial(_compute_argminmax, gt, _get_max_identity),
-    multiple_results=False)))
+mlir.register_lowering(argmax_p, mlir.cache_lowering(
+    mlir.lower_fun(partial(_compute_argminmax, gt, _get_max_identity),
+                   multiple_results=False)))
 
 
 def _reduce_logical_shape_rule(operand, *, axes):
@@ -5882,7 +5899,7 @@ def _rng_bit_generator_lowering(
 rng_bit_generator_p.def_abstract_eval(
     partial(standard_multi_result_abstract_eval, rng_bit_generator_p,
             _rng_bit_generator_shape_rule, _rng_bit_generator_dtype_rule,
-            _rng_bit_generator_weak_type_rule))
+            _rng_bit_generator_weak_type_rule, None))
 mlir.register_lowering(rng_bit_generator_p,
                        _rng_bit_generator_lowering)
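
Both new sharding rules mirror their shape-rule counterparts: they delete the reduced dimensions, only from the operand's PartitionSpec rather than its shape. Below is a standalone illustration of that spec transformation; drop_dims is a hypothetical local helper standing in for what util.tuple_delete is assumed to do on a spec, not JAX's own function.

# Illustration only; drop_dims is a hypothetical stand-in, not util.tuple_delete.
from jax.sharding import PartitionSpec as P

def drop_dims(spec, dims):
  # Remove the spec entries of the reduced axes, keeping the rest in order.
  dims = set(dims)
  return P(*(axis for i, axis in enumerate(spec) if i not in dims))

print(drop_dims(P('x', 'y'), (0,)))  # keeps only 'y' (argmax over axis 0)
print(drop_dims(P('x', 'y'), (1,)))  # keeps only 'x' (argmin over axis 1)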

jax/_src/lax/utils.py

Lines changed: 8 additions & 3 deletions
@@ -69,16 +69,21 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
   raise TypeError(avals, least_specialized)
 
 def standard_multi_result_abstract_eval(
-    prim, shape_rule, dtype_rule, weak_type_rule, *avals, **kwargs):
+    prim, shape_rule, dtype_rule, weak_type_rule, sharding_rule,
+    *avals, **kwargs):
   assert prim.multiple_results
   assert all(isinstance(aval, core.UnshapedArray) for aval in avals), avals
   least_specialized = max(map(type, avals), key=_get_array_abstraction_level)
   weak_types = weak_type_rule(*avals, **kwargs)
   if least_specialized is core.ShapedArray:
     out_shapes = shape_rule(*avals, **kwargs)
     out_dtypes = dtype_rule(*avals, **kwargs)
-    return [core.ShapedArray(s, d, weak_type=weak_type)
-            for s, d, weak_type in zip(out_shapes, out_dtypes, weak_types)]
+    out_shardings = (sharding_rule(*avals, **kwargs)
+                     if config.sharding_in_types.value else
+                     [None] * len(out_shapes))
+    return [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
+            for s, d, weak_type, sh in zip(out_shapes, out_dtypes, weak_types,
+                                           out_shardings)]
   elif least_specialized is core.UnshapedArray:
     out_dtypes = dtype_rule(*avals, **kwargs)
     return [core.UnshapedArray(dtype, weak_type=weak_type)
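
Note that sharding_rule is a positional parameter placed before *avals, so every partial application of standard_multi_result_abstract_eval must now supply it; primitives with no rule pass None, as the rng_bit_generator_p hunk above does. The new branch that chooses output shardings can be paraphrased by the toy function below (a rough, self-contained restatement, not the actual JAX code).

# Toy restatement only; the function name and flag_on argument are illustrative.
def pick_out_shardings(sharding_rule, out_shapes, flag_on, *avals, **kwargs):
  if flag_on:  # stands in for config.sharding_in_types.value
    return sharding_rule(*avals, **kwargs)
  return [None] * len(out_shapes)  # no shardings when the mode is off

print(pick_out_shardings(None, [(8,), (8,)], flag_on=False))  # [None, None]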

tests/pjit_test.py

Lines changed: 22 additions & 0 deletions
@@ -5466,6 +5466,28 @@ def g(carry, arr):
         ValueError, "0th dimension of all xs should be replicated"):
       f(carry, jax.device_put(arr, NamedSharding(mesh, P('x', None, None))))
 
+  def test_argminmax(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      z = jnp.argmax(x, axis=0)
+      self.assertEqual(z.sharding.spec, P('y'))
+      a = jnp.argmin(x, axis=1)
+      self.assertEqual(a.sharding.spec, P('x'))
+      return z, a
+
+    out1, out2 = f(arr)
+    self.assertArraysEqual(out1, np.argmax(np_inp, axis=0))
+    self.assertEqual(out1.sharding, NamedSharding(mesh, P('y')))
+    self.assertArraysEqual(out2, np.argmin(np_inp, axis=1))
+    self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
+
+    self.assertIn('@Sharding', f.lower(arr).as_text())
+
 
 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
