
Commit e854f16

yashk2810 authored and Google-ML-Automation committed
Allow P.UNCONSTRAINED in out_shardings at top-level jit. This is required for sharding-in-types to work properly when out_avals contain UNCONSTRAINED specs.

This also simplifies the `impl` rule of `sharding_cast`.

PiperOrigin-RevId: 707349491
1 parent: b56dc63

File tree: 5 files changed (+100, −28 lines)
jax/_src/interpreters/mlir.py
jax/_src/interpreters/pxla.py
jax/_src/mesh.py
jax/_src/pjit.py
tests/pjit_test.py
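For context, a minimal sketch of what this commit enables, mirroring the new test_jit_out_shardings_unconstrained below (assumes 4 available devices; the names here are illustrative, not part of the change):

    import numpy as np
    import jax
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

    mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), ('x', 'y'))
    arr = jax.device_put(np.arange(16).reshape(8, 2),
                         NamedSharding(mesh, P('x', 'y')))

    # Dim 0 is left to the partitioner; dim 1 is pinned to mesh axis 'y'.
    g = jax.jit(lambda x: x * 3,
                out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y')))
    out = g(arr)  # previously rejected at top-level jit; now dim 0 is
                  # filled in by sharding propagation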

jax/_src/interpreters/mlir.py

Lines changed: 38 additions & 6 deletions
@@ -49,7 +49,7 @@
 from jax._src.interpreters import xla
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
-from jax._src.sharding_impls import AUTO
+from jax._src.sharding_impls import AUTO, NamedSharding
 from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
@@ -1055,6 +1055,21 @@ def _get_mem_kind(s: JSharding | AUTO | None) -> str | None:
   assert isinstance(s, JSharding)
   return s.memory_kind
 
+def contains_unconstrained(s):
+  return isinstance(s, NamedSharding) and None in s._parsed_pspec
+
+def all_unconstrained(s, aval):
+  if isinstance(s, NamedSharding):
+    if aval.ndim != len(s._parsed_pspec):
+      return False
+    return all(p is None for p in s._parsed_pspec)
+  return False
+
+def _get_unconstrained_dimensions(s, aval):
+  us = contains_unconstrained(s)
+  return (us, all_unconstrained(s, aval),
+          ({i for i, p in enumerate(s._parsed_pspec) if p is None} if us else None))
+
 
 def lower_jaxpr_to_module(
     module_name: str,
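A sketch of the helper semantics (this assumes, per the predicates above, that P.UNCONSTRAINED entries parse to None in NamedSharding._parsed_pspec; mesh and aval are hypothetical, with aval of rank 2):

    from jax.sharding import NamedSharding, PartitionSpec as P

    s = NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))
    # contains_unconstrained(s)               -> True   (some dim unconstrained)
    # all_unconstrained(s, aval)              -> False  (dim 1 is pinned to 'y')
    # _get_unconstrained_dimensions(s, aval)  -> (True, False, {0})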
@@ -1114,7 +1129,8 @@ def lower_jaxpr_to_module(
         f"only {platforms_with_donation} support donation")
   if (num_partitions > 1 and
       (result_shardings is None or
-       all(s is None or isinstance(s, AUTO) for s in result_shardings))):
+       all(s is None or isinstance(s, AUTO) or contains_unconstrained(s)
+           for s in result_shardings))):
     xla_donated_args = donated_args
     donated_args = [False] * len(donated_args)
   if xla_donated_args is None:
@@ -1448,7 +1464,8 @@ def lower_jaxpr_to_fun(
   ir_arg_memory_kinds = None
   if arg_memory_kinds is not None:
     ir_arg_memory_kinds = util.flatten(
-        [[mk] * len_ir_types(types) for mk, types in zip(arg_memory_kinds, input_types)])
+        [[mk] * len_ir_types(types)
+         for mk, types in zip(arg_memory_kinds, input_types)])
 
   ir_arg_layouts = None
   if arg_layouts is not None:
@@ -1459,13 +1476,18 @@ def lower_jaxpr_to_fun(
   ir_donated_args = None
   if xla_donated_args is not None:
     ir_donated_args = util.flatten(
-        [[is_donated] * len_ir_types(types) for is_donated, types in zip(xla_donated_args, input_types)])
+        [[is_donated] * len_ir_types(types)
+         for is_donated, types in zip(xla_donated_args, input_types)])
 
   ir_result_shardings = None
+  unconstrained_shardings = None
   if result_shardings is not None:
     ir_result_shardings = util.flatten(
         [[_to_physical_op_sharding(ctx, a, s)] * len_ir_types(types)
          for a, s, types in zip(output_avals, result_shardings, output_types)])
+    unconstrained_shardings = util.flatten(
+        [[_get_unconstrained_dimensions(s, a)] * len_ir_types(types)
+         for a, s, types in zip(output_avals, result_shardings, output_types)])
 
   ir_result_memory_kinds = None
   custom_call_ir_result_memory_kinds = None
@@ -1580,8 +1602,9 @@ def lower_jaxpr_to_fun(
       attrs['jax.result_info'] = ir.StringAttr.get(name_)
 
   if use_sharding_annotations and ir_result_shardings is not None:
-    for attrs, sharding in zip(result_attrs, ir_result_shardings):
-      if sharding is not None:
+    for attrs, sharding, us in zip(result_attrs, ir_result_shardings,
+                                   unconstrained_shardings):  # type: ignore
+      if sharding is not None and not us[0]:
         if config.use_shardy_partitioner.value:
           attrs["sdy.sharding"] = get_sharding_attr(sharding)
         else:
@@ -1658,6 +1681,15 @@ def lower_jaxpr_to_fun(
           o if s is None else wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s)
           for o, s, o_aval in zip(flat_outputs, ir_result_shardings, output_avals)]
 
+      if ir_result_shardings is not None:
+        flat_outputs = [
+            wrap_with_sharding_op(entry_lowering_ctx, o, o_aval, s,
+                                  unspecified_dims=us[2])
+            if us[0] and not us[1] else o
+            for o, s, o_aval, us in zip(flat_outputs, ir_result_shardings,
+                                        output_avals, unconstrained_shardings)  # type: ignore
+        ]
+
     # Insert a custom call if output is on host because XLA needs that to do the
     # transfer.
     if custom_call_ir_result_memory_kinds is not None and name == "main":
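Taken together, these mlir.py changes (a) suppress the per-result sharding attribute whenever a result sharding contains an unconstrained dim, and (b) wrap such outputs in a sharding op whose unspecified_dims lists the unconstrained positions, unless every dim is unconstrained, in which case the output is left entirely to propagation. Per the test below, this shows up in the lowered text roughly as:

    # Shardy partitioner:  <@mesh, [{?}, {"y"}]>     ({?} marks the open dim)
    # GSPMD partitioner:   ... unspecified_dims=[0]  (dim 0 left to propagation)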

jax/_src/interpreters/pxla.py

Lines changed: 10 additions & 7 deletions
@@ -62,7 +62,7 @@
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
+from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding import Sharding as JSharding
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.sharding_impls import (
@@ -2162,10 +2162,7 @@ def _abstract_to_concrete_mesh(abstract_mesh):
 
   out = []
   for s, a in zip(shardings, avals):
-    # Remove the `UnconstrainedSingleton` logic after UNCONSTRAINED is supported
-    # in out_shardings at top level jit.
-    if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
-        all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
+    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:
@@ -2794,6 +2791,11 @@ def _maybe_get_and_check_out_shardings(
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
         xla_s = sharding_impls.logical_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
+    elif mlir.contains_unconstrained(orig):
+      if (aval is not core.abstract_token and
+          dtypes.issubdtype(aval.dtype, dtypes.extended)):
+        xla_s = sharding_impls.logical_sharding(aval, xla_s)
+      new_out_shardings.append(_gspmd_to_named_sharding(xla_s, orig))  # type: ignore
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)
       orig_hlo_s = orig._to_xla_hlo_sharding(aval.ndim)  # pytype: disable=attribute-error
@@ -2909,8 +2911,9 @@ def from_hlo(name: str,
 
   allow_prop_to_inputs = tuple(isinstance(i, (UnspecifiedValue, AUTO))
                                for i in in_shardings)
-  allow_prop_to_outputs = tuple(isinstance(o, (UnspecifiedValue, AUTO))
-                                for o in out_shardings)
+  allow_prop_to_outputs = tuple(
+      isinstance(o, (UnspecifiedValue, AUTO)) or mlir.contains_unconstrained(o)
+      for o in out_shardings)
 
   mesh = None
   if auto_spmd_lowering:
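Net effect in pxla.py: an out sharding containing unconstrained dims is now grouped with UnspecifiedValue and AUTO for output propagation, and the compiler-chosen sharding is mapped back to a NamedSharding via _gspmd_to_named_sharding. An illustration of the new allow_prop_to_outputs predicate (values hypothetical):

    # UnspecifiedValue                              -> True   (unchanged)
    # AUTO                                          -> True   (unchanged)
    # NamedSharding(mesh, P('x', 'y'))              -> False  (fully specified)
    # NamedSharding(mesh, P(P.UNCONSTRAINED, 'y'))  -> True   (new in this commit)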

jax/_src/mesh.py

Lines changed: 16 additions & 0 deletions
@@ -353,6 +353,22 @@ def abstract_mesh(self):
   def with_axis_types(self, new_axis_types) -> Mesh:
     return Mesh(self.devices, self.axis_names, axis_types=new_axis_types)
 
+  @functools.cached_property
+  def _are_all_axes_collective(self) -> bool:
+    return all(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _are_all_axes_auto(self) -> bool:
+    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_collective(self) -> bool:
+    return any(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_auto(self) -> bool:
+    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+
 
 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
 
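For reference, a sketch of what the new cached properties report, assuming (per the iteration over keys above) that at this revision Mesh.axis_types maps AxisTypes values to axis names:

    # Hypothetical mesh with axis_types == {AxisTypes.Auto: ('x', 'y')}:
    #   mesh._are_all_axes_auto        -> True
    #   mesh._are_all_axes_collective  -> False
    #   mesh._any_axis_auto            -> True
    #   mesh._any_axis_collective      -> False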
jax/_src/pjit.py

Lines changed: 7 additions & 14 deletions
@@ -2675,13 +2675,6 @@ def _sharding_constraint_batcher(
 
 # -------------------- sharding_cast ---------------------------
 
-def _check_mesh_shape_same(src_sharding, dst_sharding, aval):
-  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
-    raise ValueError(
-        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
-        ' match the mesh shape of the target sharding'
-        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
-
 def sharding_cast(xs, shardings):
   if isinstance(shardings, NamedSharding):
     return tree_map(lambda x: sharding_cast_p.bind(
@@ -2695,17 +2688,17 @@ def sharding_cast(xs, shardings):
 
 sharding_cast_p = core.Primitive('sharding_cast')
 def _sharding_cast_abstract_eval(aval, src_sharding, dst_sharding):
-  _check_mesh_shape_same(src_sharding, dst_sharding, aval)
+  if src_sharding.mesh.shape_tuple != dst_sharding.mesh.shape_tuple:
+    raise ValueError(
+        f'Mesh shape of the input {src_sharding.mesh.shape_tuple} does not'
+        ' match the mesh shape of the target sharding'
+        f' {dst_sharding.mesh.shape_tuple} for shape {aval.str_short()}')
   return aval.update(sharding=dst_sharding)
 sharding_cast_p.def_abstract_eval(_sharding_cast_abstract_eval)
 
 def _sharding_cast_impl(x, src_sharding, dst_sharding):
-  aval = shaped_abstractify(x)
-  _check_mesh_shape_same(x.sharding, dst_sharding, aval)
-  new_mesh = x.sharding.mesh.with_axis_types(dst_sharding.mesh.axis_types)
-  concrete_dst_sharding = NamedSharding(new_mesh, dst_sharding.spec)
-  # TODO(yashkatariya): Replace this with `dispatch.apply_primitive(...)`
-  return api.jit(_identity_fn, out_shardings=concrete_dst_sharding)(x)
+  return dispatch.apply_primitive(sharding_cast_p, x, src_sharding=src_sharding,
+                                  dst_sharding=dst_sharding)
 sharding_cast_p.def_impl(_sharding_cast_impl)
 
 def _sharding_cast_transpose_rule(ct, _, src_sharding, dst_sharding):
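Note on the refactor: folding _check_mesh_shape_same into the abstract eval rule means the mismatch error now surfaces at trace time for both paths, and the impl rule reduces to the standard dispatch pattern the removed TODO asked for. A sketch of the failure the check guards (meshes hypothetical):

    # src mesh 2x2 over ('x', 'y') vs. dst mesh 4x1 over ('x', 'y'):
    # sharding_cast(arr, NamedSharding(dst_mesh, P('x')))
    # -> ValueError: Mesh shape of the input ... does not match the mesh shape
    #    of the target sharding ... for shape ...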

tests/pjit_test.py

Lines changed: 29 additions & 1 deletion
@@ -4680,6 +4680,34 @@ def g(x, y):
         RuntimeError, 'A jitted computation cannot contain AbstractMesh'):
       lowered3.compile()
 
+  def test_jit_out_shardings_unconstrained(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x', 'y'))
+    np_inp = np.arange(16).reshape(8, 2)
+    arr = jax.device_put(np_inp, s)
+
+    out_s = NamedSharding(mesh, P(P.UNCONSTRAINED, P.UNCONSTRAINED))
+    @partial(jax.jit, out_shardings=out_s)
+    def f(x):
+      return x * 2
+
+    out = f(arr)
+    self.assertEqual(out.sharding, s)
+    self.assertArraysEqual(out, np_inp * 2)
+
+    @partial(jax.jit, out_shardings=NamedSharding(mesh, P(P.UNCONSTRAINED, 'y')))
+    def g(x):
+      return x * 3
+
+    out = g(arr)
+    self.assertArraysEqual(out, np_inp * 3)
+    self.assertEqual(out.sharding, s)
+    lowered_text = g.lower(arr).as_text()
+    if config.use_shardy_partitioner.value:
+      self.assertIn('<@mesh, [{?}, {"y"}]>', lowered_text)
+    else:
+      self.assertIn("unspecified_dims=[0]", lowered_text)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -5548,7 +5576,7 @@ def f(x, x2):
       return a
 
     out = f(arr, arr.T)
-    self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))
+    self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
 
   def test_auto_user(self):
     mesh = jtu.create_mesh((2, 2), ('x', 'y'),
