Skip to content

Commit 40024b1

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Expose jax.sharding.get_mesh() as a way to get the concrete mesh. This is useful for writing libraries. get_mesh() can't be used inside jax.jit just like set_mesh.
PiperOrigin-RevId: 833451082
1 parent f8f68d6 commit 40024b1

File tree

4 files changed

+12
-1
lines changed

4 files changed

+12
-1
lines changed

jax/_src/sharding_impls.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,6 +1258,14 @@ def __exit__(self, exc_type, exc_value, traceback):
12581258
config.device_context.set_local(self.prev_mesh)
12591259

12601260

1261+
def get_mesh() -> mesh_lib.Mesh:
1262+
if not core.trace_state_clean():
1263+
raise ValueError(
1264+
'`get_mesh` can only be used outside of `jax.jit`. Maybe you want'
1265+
' `jax.sharding.get_abstract_mesh()`?')
1266+
return mesh_lib.get_concrete_mesh()
1267+
1268+
12611269
@contextlib.contextmanager
12621270
def _internal_use_concrete_mesh(mesh: mesh_lib.Mesh):
12631271
assert isinstance(mesh, mesh_lib.Mesh)

jax/sharding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
SingleDeviceSharding as SingleDeviceSharding,
2222
PmapSharding as _deprecated_PmapSharding,
2323
set_mesh as set_mesh,
24+
get_mesh as get_mesh,
2425
)
2526
from jax._src.partition_spec import (
2627
PartitionSpec as PartitionSpec,

tests/documentation_coverage_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def jax_docs_dir() -> str:
7171
'jax.profiler': ['ProfileData', 'ProfileEvent', 'ProfileOptions', 'ProfilePlane', 'stop_server'],
7272
'jax.random': ['key_impl', 'random_gamma_p'],
7373
'jax.scipy.special': ['bessel_jn', 'sph_harm_y'],
74-
'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh'],
74+
'jax.sharding': ['AbstractDevice', 'AbstractMesh', 'AxisType', 'auto_axes', 'explicit_axes', 'get_abstract_mesh', 'reshard', 'set_mesh', 'use_abstract_mesh', 'get_mesh'],
7575
'jax.stages': ['ArgInfo', 'CompilerOptions'],
7676
'jax.tree_util': ['DictKey', 'FlattenedIndexKey', 'GetAttrKey', 'PyTreeDef', 'SequenceKey', 'default_registry'],
7777
}

tests/pjit_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7633,6 +7633,8 @@ def test_set_mesh(self):
76337633
jax.set_mesh(mesh)
76347634
out = reshard(np.arange(8), P('x'))
76357635
self.assertEqual(out.sharding, NamedSharding(mesh, P('x')))
7636+
out_mesh = jax.sharding.get_mesh()
7637+
self.assertEqual(out_mesh, mesh)
76367638
finally:
76377639
config.abstract_mesh_context_manager.set_local(
76387640
mesh_lib.empty_abstract_mesh)

0 commit comments

Comments
 (0)