Skip to content

Commit af63e44

Browse files
yashk2810Google-ML-Automation
authored andcommitted
[sharding_in_types] Check out_avals with mesh context too. This is because users can pass their own shardings to functions like einsum, reshape, broadcast`, etc
PiperOrigin-RevId: 707672801
1 parent 13e721a commit af63e44

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

jax/_src/lax/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,13 @@ def standard_abstract_eval(prim, shape_rule, dtype_rule, weak_type_rule,
5454
least_specialized = type(max(avals, key=_get_array_abstraction_level))
5555
if least_specialized is core.ShapedArray:
5656
core.check_avals_context_mesh(avals, prim.name)
57-
return core.ShapedArray(
57+
out_aval = core.ShapedArray(
5858
shape_rule(*avals, **kwargs), dtype_rule(*avals, **kwargs),
5959
weak_type=weak_type,
6060
sharding=(sharding_rule(*avals, **kwargs)
6161
if config.sharding_in_types.value else None))
62+
core.check_avals_context_mesh([out_aval], prim.name)
63+
return out_aval
6264
elif least_specialized is core.DShapedArray:
6365
shape = shape_rule(*avals, **kwargs)
6466
ty = (core.ShapedArray if all(type(d) is int for d in shape)
@@ -83,9 +85,11 @@ def standard_multi_result_abstract_eval(
8385
out_shardings = (sharding_rule(*avals, **kwargs)
8486
if config.sharding_in_types.value else
8587
[None] * len(out_shapes))
86-
return [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
87-
for s, d, weak_type, sh in zip(out_shapes, out_dtypes, weak_types,
88-
out_shardings)]
88+
out_avals = [core.ShapedArray(s, d, weak_type=weak_type, sharding=sh)
89+
for s, d, weak_type, sh in zip(out_shapes, out_dtypes,
90+
weak_types, out_shardings)]
91+
core.check_avals_context_mesh(out_avals, prim.name)
92+
return out_avals
8993
elif least_specialized is core.UnshapedArray:
9094
out_dtypes = dtype_rule(*avals, **kwargs)
9195
return [core.UnshapedArray(dtype, weak_type=weak_type)

tests/pjit_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5827,6 +5827,24 @@ def f(x, y):
58275827
self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
58285828
self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
58295829

5830+
@jtu.with_user_mesh((2,), 'x')
5831+
def test_output_different_context_error(self, mesh):
5832+
np_inp1 = np.arange(16).reshape(8, 2)
5833+
arr1 = jax.device_put(np_inp1, NamedSharding(mesh, P('x', None)))
5834+
arr2 = jax.device_put(np_inp1.T, NamedSharding(mesh, P(None, 'x')))
5835+
auto_mesh = jax.make_mesh((2,), 'x',
5836+
axis_types={AxisTypes.Auto: 'x'}).abstract_mesh
5837+
5838+
@jax.jit
5839+
def f(x, y):
5840+
out = jnp.einsum('xy,yz->xz', x, y,
5841+
out_type=NamedSharding(auto_mesh, P('x', None)))
5842+
return out
5843+
5844+
with self.assertRaisesRegex(
5845+
ValueError, "context mesh.* should match the aval mesh"):
5846+
f(arr1, arr2)
5847+
58305848

58315849
@jtu.pytest_mark_if_available('multiaccelerator')
58325850
class PJitErrorTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)