Skip to content

Commit 09fdd0d

Browse files
yashk2810Google-ML-Automation
authored andcommitted
[sharding_in_types] Add tests allowing inputs and outputs of jit to have different axis_types on their mesh than the axis_types on the surrounding mesh context
PiperOrigin-RevId: 707356052
1 parent e854f16 commit 09fdd0d

File tree

1 file changed

+38
-0
lines changed

1 file changed

+38
-0
lines changed

tests/pjit_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
SingleDeviceSharding, parse_flatten_op_sharding)
5555
from jax._src.pjit import pjit, sharding_cast
5656
from jax._src import mesh as mesh_lib
57+
from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes
5758
from jax._src.interpreters import pxla
5859
from jax._src.lib.mlir import dialects
5960
from jax._src import xla_bridge
@@ -5789,6 +5790,43 @@ def g(x):
57895790
out = jax.jit(jax.grad(g))(arr)
57905791
self.assertEqual(out.sharding, s)
57915792

5793+
@jtu.with_user_mesh((2,), 'x')
5794+
def test_return_output_different_context(self, mesh):
5795+
np_inp = np.arange(16).reshape(8, 2)
5796+
s = NamedSharding(mesh, P('x'))
5797+
arr = jax.device_put(np_inp, s)
5798+
5799+
@jax.jit
5800+
def f(x):
5801+
auto_mesh = get_abstract_mesh().with_axis_types({AxisTypes.Auto: 'x'})
5802+
with set_abstract_mesh(auto_mesh):
5803+
x = sharding_cast(x, x.sharding.with_mesh(auto_mesh))
5804+
return x
5805+
5806+
self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
5807+
out = f(arr)
5808+
self.assertArraysEqual(out, np_inp)
5809+
self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
5810+
5811+
@jtu.with_user_mesh((2,), 'x')
5812+
def test_inputs_different_context(self, mesh):
5813+
np_inp = np.arange(16).reshape(8, 2)
5814+
s = NamedSharding(mesh, P('x'))
5815+
arr = jax.device_put(np_inp, s)
5816+
5817+
auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Auto: 'x'})
5818+
with mesh_lib.set_mesh(auto_mesh):
5819+
arr2 = jnp.ones(8)
5820+
self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
5821+
5822+
@jax.jit
5823+
def f(x, y):
5824+
return x, y
5825+
5826+
out1, out2 = f(arr, arr2)
5827+
self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.User: 'x'})
5828+
self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'})
5829+
57925830

57935831
@jtu.pytest_mark_if_available('multiaccelerator')
57945832
class PJitErrorTest(jtu.JaxTestCase):

0 commit comments

Comments
 (0)