Skip to content

Commit 9e2708e

Browse files
yashk2810Google-ML-Automation
authored andcommitted
[sharding_in_types] Use set_mesh API to trigger sharding_in_types instead of the config option.
PiperOrigin-RevId: 702814257
1 parent fa6585d commit 9e2708e

File tree

4 files changed

+88
-105
lines changed

4 files changed

+88
-105
lines changed

jax/_src/interpreters/pxla.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,15 +2193,8 @@ def lower_sharding_computation(
21932193
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
21942194
len(out_shardings), len(out_layouts), len(global_out_avals))
21952195

2196-
if config.sharding_in_types.value:
2197-
# TODO(yashkatariya): Thread it via jit path and remove the None check by
2198-
# making tests go via set_mesh API always.
2199-
devices_from_context = (
2200-
None if mesh_lib.device_context.concrete_mesh is None
2201-
else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
2202-
else:
2203-
devices_from_context = (None if context_mesh is None or context_mesh.empty
2204-
else context_mesh._flat_devices_tuple)
2196+
devices_from_context = (None if context_mesh is None or context_mesh.empty
2197+
else context_mesh._flat_devices_tuple)
22052198
# Device assignment across all inputs, outputs and shardings inside jaxpr
22062199
# should be the same.
22072200
unique_intermediate_shardings = util.stable_unique(

jax/_src/pjit.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -707,9 +707,6 @@ def get_abstract_mesh(in_avals):
707707
f'Mesh for all inputs should be equal. Got one mesh: {m} and'
708708
f' another mesh: {a.sharding.mesh}')
709709
m = a.sharding.mesh # type: ignore
710-
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
711-
if m is None:
712-
return contextlib.nullcontext()
713710
assert isinstance(m, AbstractMesh)
714711
return m
715712

@@ -1791,8 +1788,12 @@ def _pjit_lower(
17911788
lowering_platforms: tuple[str, ...] | None,
17921789
lowering_parameters: mlir.LoweringParameters,
17931790
pgle_profiler: profiler.PGLEProfiler | None):
1794-
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
1795-
if resource_env is not None else (None, 'jit'))
1791+
if config.sharding_in_types.value:
1792+
mesh = mesh_lib.device_context.concrete_mesh
1793+
api_name = 'jit'
1794+
else:
1795+
mesh, api_name = ((resource_env.physical_mesh, 'pjit')
1796+
if resource_env is not None else (None, 'jit'))
17961797
return pxla.lower_sharding_computation(
17971798
jaxpr, api_name, name, in_shardings, out_shardings,
17981799
in_layouts, out_layouts, tuple(donated_invars),

jax/_src/test_util.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
from jax._src import pjit as pjit_lib
5252
from jax._src import stages
5353
from jax._src import xla_bridge
54+
from jax._src import mesh as mesh_lib
5455
from jax._src.cloud_tpu_init import running_in_cloud_tpu_vm
5556
from jax._src.interpreters import mlir
5657
from jax._src.interpreters import pxla
@@ -1442,6 +1443,16 @@ def with_and_without_mesh(f):
14421443
('Mesh', (('x', 2),), (('i', 'x'),))
14431444
))(with_mesh_from_kwargs(f))
14441445

1446+
def with_user_mesh(sizes, names):
1447+
def decorator(fn):
1448+
def mesh_fn(*args, **kwargs):
1449+
mesh = create_mesh(sizes, names)
1450+
with mesh_lib.set_mesh(mesh):
1451+
return fn(*args, **kwargs, mesh=mesh)
1452+
return mesh_fn
1453+
return decorator
1454+
1455+
14451456
def create_mesh(mesh_shape, axis_names, iota_order=False):
14461457
size = math.prod(mesh_shape)
14471458
if len(jax.devices()) < size:

0 commit comments

Comments
 (0)