|
54 | 54 | SingleDeviceSharding, parse_flatten_op_sharding) |
55 | 55 | from jax._src.pjit import pjit, sharding_cast |
56 | 56 | from jax._src import mesh as mesh_lib |
| 57 | +from jax._src.mesh import set_abstract_mesh, get_abstract_mesh, AxisTypes |
57 | 58 | from jax._src.interpreters import pxla |
58 | 59 | from jax._src.lib.mlir import dialects |
59 | 60 | from jax._src import xla_bridge |
@@ -5789,6 +5790,43 @@ def g(x): |
5789 | 5790 | out = jax.jit(jax.grad(g))(arr) |
5790 | 5791 | self.assertEqual(out.sharding, s) |
5791 | 5792 |
|
| 5793 | + @jtu.with_user_mesh((2,), 'x') |
| 5794 | + def test_return_output_different_context(self, mesh): |
| 5795 | + np_inp = np.arange(16).reshape(8, 2) |
| 5796 | + s = NamedSharding(mesh, P('x')) |
| 5797 | + arr = jax.device_put(np_inp, s) |
| 5798 | + |
| 5799 | + @jax.jit |
| 5800 | + def f(x): |
| 5801 | + auto_mesh = get_abstract_mesh().with_axis_types({AxisTypes.Auto: 'x'}) |
| 5802 | + with set_abstract_mesh(auto_mesh): |
| 5803 | + x = sharding_cast(x, x.sharding.with_mesh(auto_mesh)) |
| 5804 | + return x |
| 5805 | + |
| 5806 | + self.assertDictEqual(arr.sharding.mesh.axis_types, {AxisTypes.User: 'x'}) |
| 5807 | + out = f(arr) |
| 5808 | + self.assertArraysEqual(out, np_inp) |
| 5809 | + self.assertDictEqual(out.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'}) |
| 5810 | + |
| 5811 | + @jtu.with_user_mesh((2,), 'x') |
| 5812 | + def test_inputs_different_context(self, mesh): |
| 5813 | + np_inp = np.arange(16).reshape(8, 2) |
| 5814 | + s = NamedSharding(mesh, P('x')) |
| 5815 | + arr = jax.device_put(np_inp, s) |
| 5816 | + |
| 5817 | + auto_mesh = jax.make_mesh((2,), 'x', axis_types={AxisTypes.Auto: 'x'}) |
| 5818 | + with mesh_lib.set_mesh(auto_mesh): |
| 5819 | + arr2 = jnp.ones(8) |
| 5820 | + self.assertDictEqual(arr2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'}) |
| 5821 | + |
| 5822 | + @jax.jit |
| 5823 | + def f(x, y): |
| 5824 | + return x, y |
| 5825 | + |
| 5826 | + out1, out2 = f(arr, arr2) |
| 5827 | + self.assertDictEqual(out1.sharding.mesh.axis_types, {AxisTypes.User: 'x'}) |
| 5828 | + self.assertDictEqual(out2.sharding.mesh.axis_types, {AxisTypes.Auto: 'x'}) |
| 5829 | + |
5792 | 5830 |
|
5793 | 5831 | @jtu.pytest_mark_if_available('multiaccelerator') |
5794 | 5832 | class PJitErrorTest(jtu.JaxTestCase): |
|
0 commit comments