Skip to content

Commit 663ef7a

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Check the type of mesh in use_abstract_mesh and use_concrete_mesh
PiperOrigin-RevId: 738190879
1 parent 3f91b4b commit 663ef7a

File tree

3 files changed

+22
-12
lines changed

3 files changed

+22
-12
lines changed

jax/_src/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from jax._src import profiler
3434
from jax._src import util
3535
from jax._src import xla_bridge
36-
from jax._src.mesh import use_concrete_mesh
3736
from jax._src.interpreters import mlir
3837
from jax._src.interpreters import pxla
3938
from jax._src.interpreters import xla
@@ -43,7 +42,8 @@
4342
from jax._src.sharding import Sharding
4443
from jax._src.sharding_impls import (
4544
PmapSharding, SingleDeviceSharding,
46-
device_replica_id_map, hashed_index, num_addressable_indices, local_to_global_shape) # pyformat: disable
45+
device_replica_id_map, hashed_index, num_addressable_indices,
46+
local_to_global_shape, use_concrete_mesh) # pyformat: disable
4747
from jax._src.typing import ArrayLike, DLDeviceType, DTypeLike
4848
from jax._src.util import safe_zip, unzip3, use_cpp_class, use_cpp_method, cache
4949
import numpy as np

jax/_src/mesh.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -543,6 +543,10 @@ class UseAbstractMeshContextManager:
543543
__slots__ = ['mesh', 'prev']
544544

545545
def __init__(self, mesh: AbstractMesh):
546+
if not isinstance(mesh, AbstractMesh):
547+
raise ValueError(
548+
"Expected mesh of type `jax.sharding.AbstractMesh`. Got type:"
549+
f" {type(mesh)}")
546550
self.mesh = mesh
547551

548552
def __enter__(self):
@@ -557,13 +561,5 @@ def get_abstract_mesh():
557561
val = jax_config.abstract_mesh_context_manager.value
558562
return empty_abstract_mesh if val is None else val
559563

560-
@contextlib.contextmanager
561-
def use_concrete_mesh(mesh: Mesh | None):
562-
prev_val = jax_config.device_context.swap_local(mesh)
563-
try:
564-
yield
565-
finally:
566-
jax_config.device_context.set_local(prev_val)
567-
568564
def get_concrete_mesh() -> Mesh | None:
569565
return jax_config.device_context.value

jax/_src/sharding_impls.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,8 +1387,7 @@ def use_mesh(mesh: mesh_lib.Mesh):
13871387
# if not core.trace_state_clean():
13881388
# raise ValueError('`use_mesh` can only be used outside of `jax.jit`')
13891389

1390-
with (mesh_lib.use_abstract_mesh(mesh.abstract_mesh),
1391-
mesh_lib.use_concrete_mesh(mesh)):
1390+
with mesh_lib.use_abstract_mesh(mesh.abstract_mesh), use_concrete_mesh(mesh):
13921391
yield
13931392

13941393
def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
@@ -1408,3 +1407,18 @@ def set_mesh(mesh: mesh_lib.Mesh | None) -> mesh_lib.Mesh | None:
14081407
prev_mesh = config.device_context.get_global()
14091408
config.device_context.set_global(mesh)
14101409
return prev_mesh
1410+
1411+
@contextlib.contextmanager
1412+
def use_concrete_mesh(mesh: mesh_lib.Mesh | None):
1413+
if mesh is not None and not isinstance(mesh, mesh_lib.Mesh):
1414+
raise ValueError(
1415+
f"Expected mesh of type `jax.sharding.Mesh`. Got {type(mesh)}")
1416+
# TODO(yashkatariya): Enable this.
1417+
# if not core.trace_state_clean():
1418+
# raise ValueError('`use_concrete_mesh` can only be used outside of `jax.jit`.')
1419+
1420+
prev_val = config.device_context.swap_local(mesh)
1421+
try:
1422+
yield
1423+
finally:
1424+
config.device_context.set_local(prev_val)

0 commit comments

Comments
 (0)