Skip to content

Commit a5c0f20

Browse files
yashk2810Google-ML-Automation
authored andcommitted
set_mesh should return the prev_mesh instead of nothing. Users can choose to use the return value or ignore it.
PiperOrigin-RevId: 738039559
1 parent 7c5871f commit a5c0f20

File tree

2 files changed

+14
-10
lines changed

2 files changed

+14
-10
lines changed

jax/_src/sharding_impls.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1391,12 +1391,20 @@ def use_mesh(mesh: mesh_lib.Mesh):
13911391
mesh_lib.use_concrete_mesh(mesh)):
13921392
yield
13931393

1394-
def set_mesh(mesh: mesh_lib.Mesh) -> None:
1395-
if not isinstance(mesh, mesh_lib.Mesh):
1394+
def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
1395+
"""Sets the given concrete mesh globally and returns the previous concrete
1396+
mesh."""
1397+
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
13961398
raise ValueError(
13971399
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
13981400
if not core.trace_state_clean():
13991401
raise ValueError('`set_mesh` can only be used outside of `jax.jit`.')
14001402

1401-
config.abstract_mesh_context_manager.set_local(mesh.abstract_mesh)
1402-
config.device_context.set_local(mesh)
1403+
if mesh is None:
1404+
config.abstract_mesh_context_manager.set_global(mesh_lib.empty_abstract_mesh) # type: ignore
1405+
else:
1406+
config.abstract_mesh_context_manager.set_global(mesh.abstract_mesh) # type: ignore
1407+
1408+
prev_mesh = config.device_context.get_global()
1409+
config.device_context.set_global(mesh)
1410+
return prev_mesh

tests/pjit_test.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7096,16 +7096,12 @@ def f(x):
70967096

70977097
def test_set_mesh(self):
70987098
mesh = jtu.create_mesh((2,), ('x',), axis_types=(AxisType.Explicit,))
7099-
prev_mesh = config.device_context.value
7100-
prev_abstract_mesh = config.abstract_mesh_context_manager.value
71017099
try:
7102-
jax.sharding.set_mesh(mesh)
7103-
7100+
prev_mesh = jax.sharding.set_mesh(mesh)
71047101
out = reshard(np.arange(8), P('x'))
71057102
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
71067103
finally:
7107-
config.device_context.set_local(prev_mesh)
7108-
config.abstract_mesh_context_manager.set_local(prev_abstract_mesh)
7104+
jax.sharding.set_mesh(prev_mesh)
71097105

71107106
@jtu.with_user_mesh((2,), ('x',))
71117107
def test_auto_axes_late_bind(self, mesh):

0 commit comments

Comments
 (0)