Skip to content

Commit 0d2dfea

Browse files
yashk2810 authored and Google-ML-Automation committed
Add a private `set_mesh` API to enter `sharding_in_types` mode. This is how users will enable sharding-in-types mode (eventually with the correct axis types set as well, but that does not work yet).
Also add a `device_context` so that `set_mesh` correctly sets the devices the computation should run on. However, the `device_context` currently enters concrete devices into the tracing and lowering caches; this should be fixed by the other JAX context work that is underway.
PiperOrigin-RevId: 700537898
1 parent 1372669 commit 0d2dfea

File tree

6 files changed

+92
-22
lines changed

6 files changed

+92
-22
lines changed

jax/_src/config.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ def trace_context():
212212
return (axis_env_state.value, mesh_context_manager.value,
213213
xla_metadata_context_manager.value,
214214
abstract_mesh_context_manager.value,
215+
device_context.value,
215216
compute_on_context_manager.value, enable_x64.value,
216217
numpy_rank_promotion.value, default_matmul_precision.value,
217218
dynamic_shapes.value,
@@ -245,6 +246,7 @@ def trace_context():
245246
axis_env_state = ()
246247
mesh_context_manager = ()
247248
abstract_mesh_context_manager = ()
249+
device_context = ()
248250
xla_metadata_context_manager = ()
249251
compute_on_context_manager = ()
250252

@@ -255,12 +257,14 @@ def trace_context():
255257
mesh_context_manager = context.mesh_context_manager
256258
if context and context.abstract_mesh_context_manager:
257259
abstract_mesh_context_manager = context.abstract_mesh_context_manager
260+
if context and context.device_context:
261+
device_context = context.device_context
258262
if context and context.xla_metadata_context_manager:
259263
xla_metadata_context_manager = context.xla_metadata_context_manager
260264
if context and context.compute_on_context_manager:
261265
compute_on_context_manager = context.compute_on_context_manager
262266
return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
263-
xla_metadata_context_manager,
267+
device_context, xla_metadata_context_manager,
264268
compute_on_context_manager, enable_x64.value,
265269
numpy_rank_promotion.value, default_matmul_precision.value,
266270
dynamic_shapes.value,
@@ -976,6 +980,7 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
976980
axis_env_state = config_ext.Config((), include_in_jit_key=True)
977981
mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
978982
abstract_mesh_context_manager = config_ext.Config((), include_in_jit_key=True)
983+
device_context = config_ext.Config((), include_in_jit_key=True)
979984
compute_on_context_manager = config_ext.Config((), include_in_jit_key=True)
980985
xla_metadata_context_manager = config_ext.Config((), include_in_jit_key=True)
981986
else:
@@ -1019,6 +1024,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
10191024
axis_env_state: Hashable = ()
10201025
mesh_context_manager: Hashable = ()
10211026
abstract_mesh_context_manager: Hashable = ()
1027+
device_context: Hashable = ()
10221028
compute_on_context_manager: Hashable = ()
10231029
xla_metadata_context_manager: Hashable = ()
10241030

@@ -1086,6 +1092,7 @@ def set_local(self, value):
10861092
axis_env_state = JitConfig('axis_env_state')
10871093
mesh_context_manager = JitConfig('mesh_context_manager')
10881094
abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
1095+
device_context = JitConfig('device_context')
10891096
compute_on_context_manager = JitConfig('compute_on_context_manager')
10901097
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
10911098

jax/_src/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1605,7 +1605,7 @@ def get_sharding(sharding, ndim):
16051605
assert len(sharding.spec) == ndim
16061606
return sharding
16071607

1608-
context_mesh = mesh_lib.mesh_context.mesh
1608+
context_mesh = mesh_lib.abstract_mesh_context.mesh
16091609
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
16101610
# code.
16111611
if context_mesh is None:

jax/_src/interpreters/pxla.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,8 +2193,15 @@ def lower_sharding_computation(
21932193
assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
21942194
len(out_shardings), len(out_layouts), len(global_out_avals))
21952195

2196-
devices_from_context = (None if context_mesh is None or context_mesh.empty
2197-
else context_mesh._flat_devices_tuple)
2196+
if config.sharding_in_types.value:
2197+
# TODO(yashkatariya): Thread it via jit path and remove the None check by
2198+
# making tests go via set_mesh API always.
2199+
devices_from_context = (
2200+
None if mesh_lib.device_context.concrete_mesh is None
2201+
else mesh_lib.device_context.concrete_mesh._flat_devices_tuple)
2202+
else:
2203+
devices_from_context = (None if context_mesh is None or context_mesh.empty
2204+
else context_mesh._flat_devices_tuple)
21982205
# Device assignment across all inputs, outputs and shardings inside jaxpr
21992206
# should be the same.
22002207
unique_intermediate_shardings = util.stable_unique(

jax/_src/mesh.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -455,10 +455,10 @@ def local_mesh(self):
455455
_raise_value_error("local_mesh")
456456

457457
def __enter__(self):
458-
return push_mesh_context(self)
458+
return push_abstract_mesh_context(self)
459459

460460
def __exit__(self, exc_type, exc_value, traceback):
461-
pop_mesh_context()
461+
pop_abstract_mesh_context()
462462
return False
463463

464464
@staticmethod
@@ -473,36 +473,70 @@ def _raise_value_error(name):
473473
raise ValueError(f"AbstractMesh does not implement {name}")
474474

475475

476-
class MeshContext(threading.local):
476+
class AbstractMeshContext(threading.local):
477477
def __init__(self):
478478
self.stack = [None]
479479
self.mesh = self.stack[-1]
480480

481-
mesh_context = MeshContext()
481+
abstract_mesh_context = AbstractMeshContext()
482482

483-
def push_mesh_context(val):
484-
mesh_context.stack.append(val)
485-
mesh_context.mesh = val
483+
def push_abstract_mesh_context(val):
484+
abstract_mesh_context.stack.append(val)
485+
abstract_mesh_context.mesh = val
486486
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
487487
# Right now that leads to weird numerical issues.
488-
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
488+
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
489+
if m is not None)
489490
if non_none_meshes:
490491
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
491492
return val
492493

493-
def pop_mesh_context():
494-
mesh_context.stack.pop()
495-
mesh_context.mesh = mesh_context.stack[-1]
496-
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
494+
def pop_abstract_mesh_context():
495+
abstract_mesh_context.stack.pop()
496+
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
497+
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
498+
if m is not None)
497499
if non_none_meshes:
498500
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
499501

500502

501503
class null_mesh_context:
502504

503505
def __enter__(self):
504-
return push_mesh_context(None)
506+
return push_abstract_mesh_context(None)
505507

506508
def __exit__(self, *excinfo):
507-
pop_mesh_context()
509+
pop_abstract_mesh_context()
508510
return False
511+
512+
513+
@contextlib.contextmanager
514+
def set_mesh(mesh: Mesh):
515+
with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
516+
enter_device_context(mesh)):
517+
yield
518+
519+
520+
class DeviceContext(threading.local):
521+
def __init__(self):
522+
self.stack = [None]
523+
self.concrete_mesh = self.stack[-1]
524+
525+
device_context = DeviceContext()
526+
527+
528+
@contextlib.contextmanager
529+
def enter_device_context(mesh: Mesh):
530+
device_context.stack.append(mesh)
531+
device_context.concrete_mesh = mesh
532+
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
533+
if non_none_meshes:
534+
jax_config.device_context.set_local(non_none_meshes)
535+
try:
536+
yield
537+
finally:
538+
device_context.stack.pop()
539+
device_context.concrete_mesh = device_context.stack[-1]
540+
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
541+
if non_none_meshes:
542+
jax_config.device_context.set_local(non_none_meshes)

jax/_src/pjit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -644,8 +644,8 @@ def _infer_params_impl(
644644
attr_token = _attr_token(flat_fun, in_type)
645645

646646
abstract_mesh = (
647-
get_abstract_mesh(in_type) if mesh_lib.mesh_context.mesh is None
648-
else mesh_lib.mesh_context.mesh)
647+
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
648+
else mesh_lib.abstract_mesh_context.mesh)
649649
with abstract_mesh:
650650
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
651651
flat_fun, in_type, attr_token, dbg,

tests/pjit_test.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4622,6 +4622,28 @@ def f(x):
46224622
ins, _ = f.lower(np.arange(8)).compile().input_shardings
46234623
self.assertEqual(ins[0], SingleDeviceSharding(jax.devices()[0]))
46244624

4625+
def test_sharding_in_types_with_set_mesh(self):
4626+
if config.use_shardy_partitioner.value:
4627+
self.skipTest("ShiT doesn't work with shardy")
4628+
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
4629+
with mesh_lib.set_mesh(mesh):
4630+
np_inp = np.arange(16.).reshape(8, 2)
4631+
s = NamedSharding(mesh, P('x', 'y'))
4632+
arr = jax.device_put(np_inp, s)
4633+
4634+
@jax.jit
4635+
def f(x):
4636+
self.assertEqual(x.sharding.spec, s.spec)
4637+
x = x * 2
4638+
self.assertEqual(x.sharding.spec, s.spec)
4639+
x = x * x
4640+
self.assertEqual(x.sharding.spec, s.spec)
4641+
return x
4642+
4643+
out = f(arr)
4644+
self.assertEqual(out.sharding, s)
4645+
self.assertArraysEqual(out, (np_inp * 2) * (np_inp * 2))
4646+
46254647

46264648
def spec_regex(s):
46274649
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
@@ -5229,7 +5251,7 @@ def test_shard_map_full_manual(self):
52295251
def g(x, y):
52305252
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
52315253
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
5232-
self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective)
5254+
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
52335255
return x * y
52345256

52355257
@jax.jit
@@ -5254,7 +5276,7 @@ def test_shard_map_dot(self):
52545276
def g(x, y):
52555277
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
52565278
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
5257-
self.assertTrue(mesh_lib.mesh_context.mesh._are_all_axes_collective)
5279+
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
52585280
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
52595281
z = x @ allgatherd_y
52605282
return jax.lax.psum(z, axis_name='y')

0 commit comments

Comments
 (0)