Commit b2b3867

yashk2810 authored and Google-ML-Automation committed
Make sharding_in_types work with Shardy
PiperOrigin-RevId: 713479962
1 parent fb832af commit b2b3867
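
At a high level, this commit makes the experimental sharding-in-types lowering path emit Shardy annotations: shardings that previously lowered to `@Sharding` custom calls (with `unspecified_dims` for Auto mesh axes) now lower to `sdy.sharding_constraint` ops with open dimensions (`{?}`) and explicit `replicated={...}` axis lists when `jax_use_shardy_partitioner` is enabled. The snippet below is a minimal sketch of how to observe the Shardy-vs-GSPMD difference in lowered text; it uses the public `with_sharding_constraint` API rather than the sharding-in-types machinery touched here, and the exact lowered strings may vary across JAX versions.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Opt into the Shardy partitioner (off by default at the time of this commit).
jax.config.update('jax_use_shardy_partitioner', True)

# A one-device mesh is enough to see the lowering difference; the constraint
# may be optimized away during compilation, but it is visible in lower().
mesh = Mesh(np.array(jax.devices()[:1]), ('x',))

@jax.jit
def f(x):
  # with_sharding_constraint forces a sharding op into the lowered module.
  return jax.lax.with_sharding_constraint(x * 2, NamedSharding(mesh, P('x')))

lowered_text = f.lower(jnp.arange(8.0)).as_text()
# Under Shardy the constraint appears as an sdy.sharding_constraint op;
# with the flag off, the same program lowers to an @Sharding custom call.
print('sdy.sharding_constraint' in lowered_text)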

6 files changed, +116 -49 lines changed


jax/_src/interpreters/mlir.py

Lines changed: 27 additions & 16 deletions
@@ -49,7 +49,8 @@
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO, NamedSharding
+from jax._src.sharding_impls import (AUTO, NamedSharding,
+                                     modify_sdy_sharding_wrt_axis_types)
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -1689,13 +1690,17 @@ def lower_jaxpr_to_fun(
        for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]

   if ir_result_shardings is not None:
-    flat_outputs = [
-        wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
-                              unspecified_dims=us[2])
-        if us[0] and not us[1] else o
-        for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
-                                    output_avals, unconstrained_shardings)  # type: ignore
-    ]
+    temp_flat_outputs = []
+    for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                output_avals, unconstrained_shardings):  # type: ignore
+      if us[0] and not us[1]:
+        if config.use_shardy_partitioner.value and config.sharding_in_types.value:
+          s = modify_sdy_sharding_wrt_axis_types(s, o_aval.sharding.mesh)
+        temp_flat_outputs.append(wrap_with_sharding_op(
+            entry_lowering_ctx, o, o_aval, s, unspecified_dims=us[2]))
+      else:
+        temp_flat_outputs.append(o)
+    flat_outputs = temp_flat_outputs

   # Insert a custom call if output is on host because XLA needs that to do the
   # transfer.
@@ -2594,14 +2599,20 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
     return op
   # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
   # `return op` early and avoid bloating HLO size.
-  proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
-           if sharding_proto is None else sharding_proto)
-  unspecified_dims = None
-  if aval.sharding.mesh._any_axis_collective:
-    unspecified_dims = set(range(aval.ndim))
-  elif aval.sharding.mesh._any_axis_auto:
-    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
-  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)
+  if config.use_shardy_partitioner.value:
+    proto = (aval.sharding._to_sdy_sharding(aval.ndim)
+             if sharding_proto is None else sharding_proto)
+    proto = modify_sdy_sharding_wrt_axis_types(proto, aval.sharding.mesh)
+    return wrap_with_sharding_op(ctx, op, aval, proto)
+  else:
+    proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+             if sharding_proto is None else sharding_proto)
+    unspecified_dims = None
+    if aval.sharding.mesh._any_axis_auto:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes
+      # as unspecified?
+      unspecified_dims = {i for i, s in enumerate(aval.sharding.spec) if s is None}
+    return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):

jax/_src/interpreters/pxla.py

Lines changed: 3 additions & 6 deletions
@@ -2163,12 +2163,9 @@ def _abstract_to_concrete_mesh(abstract_mesh):
   out = []
   for s, a in zip(shardings, avals):
     if isinstance(s, UnspecifiedValue) and a.sharding is not None:
-      if config.use_shardy_partitioner.value:
-        spec = a.sharding.spec
-      else:
-        spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
-                                for sp in a.sharding.spec])
-                if a.sharding.mesh._any_axis_auto else a.sharding.spec)
+      spec = (PartitionSpec(*[PartitionSpec.UNCONSTRAINED if sp is None else sp
                               for sp in a.sharding.spec])
+              if a.sharding.mesh._any_axis_auto else a.sharding.spec)
       out.append(NamedSharding(
           _abstract_to_concrete_mesh(a.sharding.mesh), spec))
     else:

jax/_src/lax/lax.py

Lines changed: 1 addition & 2 deletions
@@ -2467,8 +2467,7 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
     if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
       out.append(op)
     else:
-      proto = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
-      out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval, proto))
+      out.append(mlir.lower_sharding_under_shit(ctx, op, out_aval))
   return out

jax/_src/pjit.py

Lines changed: 3 additions & 1 deletion
@@ -2710,7 +2710,9 @@ def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
 def _sharding_cast_hlo_lowering(ctx, x_node, *, src_sharding, dst_sharding):
   aval, = ctx.avals_in
   aval_out, = ctx.avals_out
-  proto = dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
+  proto = (dst_sharding._to_sdy_sharding(aval.ndim)
+           if config.use_shardy_partitioner.value else
+           dst_sharding._to_xla_hlo_sharding(aval.ndim).to_proto())
   return [mlir.lower_sharding_under_shit(ctx, x_node, aval_out, proto)]
 mlir.register_lowering(sharding_cast_p, _sharding_cast_hlo_lowering)

jax/_src/sharding_impls.py

Lines changed: 23 additions & 2 deletions
@@ -142,6 +142,7 @@ class SdyArraySharding:
   mesh_shape: tuple[tuple[str, int], ...] | None
   dimension_shardings: Sequence[SdyDimSharding]
   logical_device_ids: tuple[int, ...] | None = None
+  replicated_axes: tuple[str, ...] = ()

   # NOTE: An MLIR context is required as a context manager.
   def build(self) -> sdy.TensorShardingAttr:
@@ -155,14 +156,17 @@ def build(self) -> sdy.TensorShardingAttr:
         ldi)
     return sdy.TensorShardingAttr.get(
         mesh_attr,
-        [dim_sharding.build() for dim_sharding in self.dimension_shardings])
+        [dim_sharding.build() for dim_sharding in self.dimension_shardings],
+        replicated_axes=[sdy.AxisRefAttr.get(axis) for axis in self.replicated_axes])

   def __repr__(self):
     dim_sharding_repr = ', '.join(
         d._custom_repr() for d in self.dimension_shardings)
     device_id_repr = (f', device_ids={self.logical_device_ids}'
                       if self.logical_device_ids is not None else '')
-    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr})"
+    rar = (f', replicated_axes={self.replicated_axes}'
+           if self.replicated_axes else '')
+    return f"SdyArraySharding([{dim_sharding_repr}]{device_id_repr}{rar})"


 @util.cache(max_size=4096, trace_context_in_key=False)
@@ -425,6 +429,23 @@ def _to_sdy_sharding(self, num_dimensions: int) -> SdyArraySharding:
     return SdyArraySharding(self.mesh.shape_tuple, dim_shardings,
                             self._logical_device_ids)

+# TODO(yashkatariya): Upstream this into `_to_sdy_sharding` maybe with an extra
+# parameter to it `_to_sdy_sharding(self, ndim, modify_wrt_axis_types=False)`
+def modify_sdy_sharding_wrt_axis_types(sdy_sharding: SdyArraySharding, mesh):
+  if mesh._any_axis_auto:
+    dim_shardings, used_axes = [], []  # type: ignore
+    for d in sdy_sharding.dimension_shardings:
+      # TODO(yashkatariya): Maybe if any mesh axis is auto, mark all axes as open?
+      dim_shardings.append(SdyDimSharding(axes=[], is_closed=False)
+                           if not d.axes and d.is_closed else d)
+      used_axes.extend(d.axes)
+    remaining_axes = set(mesh.axis_names) - set(used_axes)
+    replicated_axes = tuple(r for r in remaining_axes
+                            if mesh._name_to_type[r] == mesh_lib.AxisTypes.User)
+    return SdyArraySharding(sdy_sharding.mesh_shape, dim_shardings,
+                            sdy_sharding.logical_device_ids, replicated_axes)
+  return sdy_sharding
+

 @util.cache(max_size=128, trace_context_in_key=False)
 def get_replicated_hlo_sharding():
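
The new `modify_sdy_sharding_wrt_axis_types` helper above is the core of the change: on a mesh with any Auto axes, dimensions whose sharding is empty and closed are reopened (so the partitioner may shard them), and User axes not used by any dimension are pinned as explicitly replicated. Below is a self-contained toy model of that logic using stand-in dataclasses, not JAX's internal `SdyDimSharding`/`SdyArraySharding` classes; it only illustrates the transformation the helper performs.

from dataclasses import dataclass
from enum import Enum

class AxisType(Enum):
  USER = 'user'
  AUTO = 'auto'

@dataclass
class DimSharding:    # stand-in for jax._src.sharding_impls.SdyDimSharding
  axes: tuple
  is_closed: bool

@dataclass
class ArraySharding:  # stand-in for jax._src.sharding_impls.SdyArraySharding
  dim_shardings: tuple
  replicated_axes: tuple = ()

def modify_wrt_axis_types(s, axis_types):
  """Toy version of modify_sdy_sharding_wrt_axis_types from this commit."""
  if not any(t is AxisType.AUTO for t in axis_types.values()):
    return s  # nothing to adjust on a mesh with no Auto axes
  dims, used = [], []
  for d in s.dim_shardings:
    # An empty, closed dimension ({}) is reopened ({?}) so the partitioner
    # may place Auto axes on it.
    dims.append(DimSharding(axes=(), is_closed=False)
                if not d.axes and d.is_closed else d)
    used.extend(d.axes)
  # User axes that no dimension uses become explicitly replicated.
  replicated = tuple(a for a, t in axis_types.items()
                     if a not in used and t is AxisType.USER)
  return ArraySharding(tuple(dims), replicated)

# P('x', None, None) on a mesh with User axes x, y and an Auto axis z:
sharding = ArraySharding((DimSharding(('x',), True),
                          DimSharding((), True),
                          DimSharding((), True)))
print(modify_wrt_axis_types(
    sharding, {'x': AxisType.USER, 'y': AxisType.USER, 'z': AxisType.AUTO}))
# -> dims [{x} closed, {} open, {} open], replicated_axes=('y',)

The result mirrors the string the new test below asserts on in the lowered module: [{"x"}, {?}, {?}], replicated={"y"}.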

tests/pjit_test.py

Lines changed: 59 additions & 22 deletions
@@ -4729,9 +4729,14 @@ def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")


-@jtu.with_config(jax_use_shardy_partitioner=False)
 class ShardingInTypesTest(jtu.JaxTestCase):

+  def check_wsc_in_lowered(self, text):
+    if config.use_shardy_partitioner.value:
+      self.assertIn('sdy.sharding_constraint', text)
+    else:
+      self.assertIn('@Sharding', text)
+
   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_basic_mul(self, mesh):
     np_inp = np.arange(16.).reshape(8, 2)
@@ -4753,7 +4758,7 @@ def f(x):

     lowered_text = f.lower(arr).as_text()
     if config.use_shardy_partitioner.value:
-      self.assertIn('sdy.sharding_constraint', lowered_text)
+      self.assertEqual(lowered_text.count('sdy.sharding_constraint'), 3)
     else:
       self.assertEqual(lowered_text.count('@Sharding'), 3)

@@ -4834,7 +4839,7 @@ def f(x, y):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr1, arr2)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())

     compiled_text = lowered.compile().as_text()
     if collective_name is not None and compiled_text is not None:
@@ -4971,7 +4976,7 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())

     compiled_text = lowered.compile().as_text()
     if reduce and compiled_text is not None:
@@ -5002,7 +5007,7 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered = f.lower(arr)
-    self.assertIn('@Sharding', lowered.as_text())
+    self.check_wsc_in_lowered(lowered.as_text())

     compiled_text = lowered.compile().as_text()
     if reduce and compiled_text is not None:
@@ -5044,7 +5049,7 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, out_spec))

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @parameterized.named_parameters(
       ('2', 2),
@@ -5068,7 +5073,7 @@ def f(x):
     self.assertArraysEqual(out, np_inp ** pow)

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @jtu.with_user_mesh((1,), 'x')
   def test_broadcasting_nary_error(self, mesh):
@@ -5102,7 +5107,7 @@ def f(x):
     self.assertEqual(out.sharding, s)

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_jnp_array(self, mesh):
@@ -5137,7 +5142,7 @@ def f(x):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'z', 'x')))

     lowered_text = f.lower(arr).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_broadcasted_iota_with_sharding(self, mesh):
@@ -5182,7 +5187,7 @@ def f(x, y):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

     lowered_text = f.lower(arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     @jax.jit
     def g(x, y):
@@ -5228,7 +5233,7 @@ def h(x, y):
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None, 'y', None)))

     lowered_text = h.lower(arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     @jax.jit
     def h2(x, y):
@@ -5268,7 +5273,7 @@ def f(x, new_sharding):
     self.assertArraysEqual(out, np_inp.reshape(dst_shape) * 2)

     lowered_text = f.lower(arr, new_s).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     def g(x):
       out = f(x, new_s)
@@ -5295,7 +5300,7 @@ def f(pred, on_true, on_false):
     self.assertArraysEqual(out, arr1)

     lowered_text = f.lower(arr1 == arr2, arr1, arr2).as_text()
-    self.assertIn('@Sharding', lowered_text)
+    self.check_wsc_in_lowered(lowered_text)

     arr3 = jax.device_put(np_inp, NamedSharding(mesh, P('y', 'x')))
     with self.assertRaisesRegex(
@@ -5383,7 +5388,7 @@ def f(x):

     out = f(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())

     def g(x):
       out = f(x)
@@ -5414,7 +5419,7 @@ def f(x):

     out = f(arr)
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())
     self.assertArraysEqual(out, np.squeeze(np_inp, axis=2))

     def g(x):
@@ -5441,7 +5446,7 @@ def f(x, padding_config, spec):
     out = f(arr, ((2, 2, 0),), P('x'))
     self.assertArraysEqual(out, np.pad(np_inp, 2))
     self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
-    self.assertIn('@Sharding', f.lower(arr, ((2, 2, 0),), P('x')).as_text())
+    self.check_wsc_in_lowered(f.lower(arr, ((2, 2, 0),), P('x')).as_text())

     out = f(arr, ((0, 0, 0),), P('x'))
     self.assertArraysEqual(out, np_inp)
@@ -5489,7 +5494,7 @@ def f(x, y, method='jnp'):
     out = f(arr1, arr2)
     self.assertEqual(out.sharding, s)
     self.assertArraysEqual(out, np.concatenate([arr1, arr2], axis=1))
-    self.assertIn('@Sharding', f.lower(arr1, arr2).as_text())
+    self.check_wsc_in_lowered(f.lower(arr1, arr2).as_text())

     out = f(arr1, arr2, method='lax')
     self.assertEqual(out.sharding, s)
@@ -5568,8 +5573,7 @@ def f(x):
     self.assertEqual(out1.sharding, NamedSharding(mesh, P('y')))
     self.assertArraysEqual(out2, np.argmin(np_inp, axis=1))
     self.assertEqual(out2.sharding, NamedSharding(mesh, P('x')))
-
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())

   @jtu.with_user_mesh((2, 2), ('x', 'y'), {mesh_lib.AxisTypes.Auto: ('x', 'y')})
   def test_only_auto(self, mesh):
@@ -5618,7 +5622,10 @@ def f(x, x2):
     out = f(arr, arr2)
     self.assertEqual(out.sharding, NamedSharding(mesh2, P('x',)))
     lowered_text = f.lower(arr, arr2).as_text()
-    self.assertTrue(lowered_text.count("unspecified_dims") == 5)
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count("{?}") == 5)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims") == 5)

     mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
                             axis_types={mesh_lib.AxisTypes.User: 'y',
@@ -5629,7 +5636,12 @@ def f(x, x2):
     out = f(arr, arr2)
     self.assertEqual(out.sharding, NamedSharding(mesh3, P('x',)))
     lowered_text = f.lower(arr, arr2).as_text()
-    self.assertTrue(lowered_text.count("unspecified_dims") == 4)
+    print(lowered_text)
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count("{?}") == 5)
+      self.assertIn('replicated={"y"}', lowered_text)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims") == 4)

     with self.assertRaisesRegex(
         ValueError,
@@ -5784,7 +5796,7 @@ def f(x, sizes=(4, 4), axis=0):
       return ys

     f(arr)
-    self.assertIn('@Sharding', f.lower(arr).as_text())
+    self.check_wsc_in_lowered(f.lower(arr).as_text())

     with self.assertRaisesRegex(NotImplementedError, "split on sharded dims"):
       f(arr, sizes=(1, 1), axis=1)
@@ -5864,6 +5876,31 @@ def g(x, y):
         ValueError, "PartitionSpec cannot contain axis names.*Auto"):
       g(arr1, arr2)

+  @jtu.with_user_mesh((2, 2, 2), ('x', 'y', 'z'),
+                      axis_types={AxisTypes.User: ('x', 'y'),
+                                  AxisTypes.Auto: 'z'})
+  def test_out_sharding_mix_axis_types(self, mesh):
+    np_inp = np.arange(16).reshape(4, 2, 2)
+    s = NamedSharding(mesh, P('x', None, None))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x):
+      y = x * 2
+      self.assertEqual(y.sharding.spec, P('x', None, None))
+      return y
+
+    out = f(arr)
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
+    self.assertArraysEqual(out, np_inp * 2)
+
+    lowered_text = f.lower(arr).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertTrue(lowered_text.count(
+          '[{"x"}, {?}, {?}], replicated={"y"}') == 3)
+    else:
+      self.assertTrue(lowered_text.count("unspecified_dims=[1,2]") == 3)
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
