Skip to content

Commit a735bf8

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Simply abstract_mesh and device_context context managers and handle everything via their corresponding configs in config.py
PiperOrigin-RevId: 702852769
1 parent 1a3c9c4 commit a735bf8

File tree

7 files changed

+43
-54
lines changed

7 files changed

+43
-54
lines changed

jax/_src/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1079,15 +1079,21 @@ class JitConfig:
10791079
def __init__(self, name):
10801080
self._name = name
10811081

1082+
@property
10821083
def value(self):
1083-
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
1084+
return self.get_local()
10841085

10851086
def get_local(self):
10861087
return getattr(jax_jit.thread_local_state().extra_jit_context, self._name)
10871088

10881089
def set_local(self, value):
10891090
update_thread_local_jit_state(**{self._name: value})
10901091

1092+
def swap_local(self, new_value):
1093+
prev_value = self.value
1094+
self.set_local(new_value)
1095+
return prev_value
1096+
10911097
trace_state = JitConfig('trace_state')
10921098
axis_env_state = JitConfig('axis_env_state')
10931099
mesh_context_manager = JitConfig('mesh_context_manager')

jax/_src/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1605,10 +1605,10 @@ def get_sharding(sharding, ndim):
16051605
assert len(sharding.spec) == ndim
16061606
return sharding
16071607

1608-
context_mesh = mesh_lib.abstract_mesh_context.mesh
1608+
context_mesh = mesh_lib.get_abstract_mesh()
16091609
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
16101610
# code.
1611-
if context_mesh is None:
1611+
if not context_mesh:
16121612
return None
16131613
assert sharding is None
16141614
return NamedSharding(context_mesh, P(*[None] * ndim))

jax/_src/mesh.py

Lines changed: 19 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -454,18 +454,6 @@ def local_devices(self):
454454
def local_mesh(self):
455455
_raise_value_error("local_mesh")
456456

457-
def __enter__(self):
458-
abstract_mesh_context.stack.append(self)
459-
abstract_mesh_context.mesh = self
460-
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
461-
return self
462-
463-
def __exit__(self, exc_type, exc_value, traceback):
464-
abstract_mesh_context.stack.pop()
465-
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
466-
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
467-
return False
468-
469457
@staticmethod
470458
def _extremely_unsafe_enter_tracing_context(mesh: AbstractMesh):
471459
jax_config.abstract_mesh_context_manager.set_local(mesh)
@@ -478,37 +466,32 @@ def _raise_value_error(name):
478466
raise ValueError(f"AbstractMesh does not implement {name}")
479467

480468

481-
class AbstractMeshContext(threading.local):
482-
def __init__(self):
483-
self.stack = [None]
484-
self.mesh = self.stack[-1]
469+
@contextlib.contextmanager
470+
def set_abstract_mesh(mesh: AbstractMesh):
471+
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
472+
try:
473+
yield
474+
finally:
475+
jax_config.abstract_mesh_context_manager.set_local(prev_val)
485476

486-
abstract_mesh_context = AbstractMeshContext()
477+
def get_abstract_mesh():
478+
return jax_config.abstract_mesh_context_manager.value
487479

488480

489481
@contextlib.contextmanager
490-
def set_mesh(mesh: Mesh):
491-
with (mesh.abstract_mesh, jax_config.sharding_in_types(True),
492-
enter_device_context(mesh)):
482+
def set_concrete_mesh(mesh: Mesh):
483+
prev_val = jax_config.device_context.swap_local(mesh)
484+
try:
493485
yield
486+
finally:
487+
jax_config.device_context.set_local(prev_val)
494488

495-
496-
class DeviceContext(threading.local):
497-
def __init__(self):
498-
self.stack = [None]
499-
self.concrete_mesh = self.stack[-1]
500-
501-
device_context = DeviceContext()
489+
def get_concrete_mesh():
490+
return jax_config.device_context.value
502491

503492

504493
@contextlib.contextmanager
505-
def enter_device_context(mesh: Mesh):
506-
device_context.stack.append(mesh)
507-
device_context.concrete_mesh = mesh
508-
jax_config.device_context.set_local(device_context.concrete_mesh)
509-
try:
494+
def set_mesh(mesh: Mesh):
495+
with (set_abstract_mesh(mesh.abstract_mesh),
496+
jax_config.sharding_in_types(True), set_concrete_mesh(mesh)):
510497
yield
511-
finally:
512-
device_context.stack.pop()
513-
device_context.concrete_mesh = device_context.stack[-1]
514-
jax_config.device_context.set_local(device_context.concrete_mesh)

jax/_src/pjit.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from collections import defaultdict
1818
from collections.abc import Callable, Sequence, Iterable
19-
import contextlib
2019
import dataclasses
2120
from functools import partial
2221
import inspect
@@ -187,7 +186,7 @@ def _python_pjit_helper(fun, jit_info, *args, **kwargs):
187186
try:
188187
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
189188
# and set the context manager down the stack?
190-
with p.abstract_mesh:
189+
with mesh_lib.set_abstract_mesh(p.abstract_mesh):
191190
if (core.trace_state_clean() and
192191
not config.debug_key_reuse.value and
193192
not config.data_dependent_tracing_fallback.value):
@@ -645,9 +644,9 @@ def _infer_params_impl(
645644
attr_token = _attr_token(flat_fun, in_type)
646645

647646
abstract_mesh = (
648-
get_abstract_mesh(in_type) if mesh_lib.abstract_mesh_context.mesh is None
649-
else mesh_lib.abstract_mesh_context.mesh)
650-
with abstract_mesh:
647+
get_abstract_mesh_from_avals(in_type)
648+
if not mesh_lib.get_abstract_mesh() else mesh_lib.get_abstract_mesh())
649+
with mesh_lib.set_abstract_mesh(abstract_mesh):
651650
jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
652651
flat_fun, in_type, attr_token, dbg,
653652
HashableFunction(res_paths, closure=()),
@@ -694,9 +693,9 @@ def _infer_params_impl(
694693
attrs_tracked, abstract_mesh), args_flat
695694

696695

697-
def get_abstract_mesh(in_avals):
696+
def get_abstract_mesh_from_avals(in_avals):
698697
if not config.sharding_in_types.value:
699-
return contextlib.nullcontext()
698+
return None
700699
m = None
701700
for a in in_avals:
702701
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
@@ -1789,7 +1788,8 @@ def _pjit_lower(
17891788
lowering_parameters: mlir.LoweringParameters,
17901789
pgle_profiler: profiler.PGLEProfiler | None):
17911790
if config.sharding_in_types.value:
1792-
mesh = mesh_lib.device_context.concrete_mesh
1791+
cur_mesh = mesh_lib.get_concrete_mesh()
1792+
mesh = cur_mesh if isinstance(cur_mesh, mesh_lib.Mesh) else None
17931793
api_name = 'jit'
17941794
else:
17951795
mesh, api_name = ((resource_env.physical_mesh, 'pjit')

jax/_src/stages.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
"""
3131
from __future__ import annotations
3232

33-
import contextlib
3433
import functools
3534
from collections.abc import Sequence
3635
from dataclasses import dataclass
@@ -44,6 +43,7 @@
4443
from jax._src import traceback_util
4544
from jax._src import tree_util
4645
from jax._src import util
46+
from jax._src import mesh as mesh_lib
4747
from jax._src.sharding_impls import UnspecifiedValue, AUTO
4848
from jax._src.layout import Layout
4949
from jax._src.interpreters import mlir
@@ -717,7 +717,7 @@ class Traced(Stage):
717717
"_args_flat", "_arg_names", "_num_consts"]
718718

719719
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
720-
lower_callable, abstract_mesh=contextlib.nullcontext(),
720+
lower_callable, abstract_mesh=None,
721721
args_flat=None, arg_names=None, num_consts: int = 0):
722722
self.jaxpr = jaxpr
723723
self.args_info = args_info
@@ -747,7 +747,7 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
747747
try:
748748
# TODO(yashkatariya): Maybe thread this into pjit params like resource_env
749749
# and set the context manager down the stack?
750-
with self._abstract_mesh:
750+
with mesh_lib.set_abstract_mesh(self._abstract_mesh):
751751
lowering = new_callable()
752752
except pxla.DeviceAssignmentMismatchError as e:
753753
fails, = e.args

jax/experimental/shard_map.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@
4646
from jax._src import traceback_util
4747
from jax._src import util
4848
from jax._src.core import Tracer
49-
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes
49+
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
5050
from jax._src.api import _shared_code_pmap, _prepare_pmap
5151
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
5252
windowed_reductions, convolution, fft, linalg,
@@ -484,7 +484,7 @@ def _shard_map_staging(
484484
in_avals = [t.aval for t in in_tracers]
485485
in_avals_ = map(partial(_shard_aval, mesh), in_names, in_avals)
486486
with (core.extend_axis_env_nd(list(mesh.shape.items())),
487-
pjit.get_abstract_mesh(in_avals_)):
487+
set_abstract_mesh(pjit.get_abstract_mesh_from_avals(in_avals_))):
488488
jaxpr, out_avals_, consts, () = pe.trace_to_jaxpr_dynamic(f, in_avals_)
489489
_check_names(out_names_thunk(), out_avals_)
490490
if check_rep:

tests/pjit_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5237,7 +5237,7 @@ def test_shard_map_full_manual(self, mesh):
52375237
def g(x, y):
52385238
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
52395239
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
5240-
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
5240+
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
52415241
return x * y
52425242

52435243
@jax.jit
@@ -5262,7 +5262,7 @@ def test_shard_map_dot(self, mesh):
52625262
def g(x, y):
52635263
self.assertTrue(x.sharding.mesh._are_all_axes_collective)
52645264
self.assertTrue(y.sharding.mesh._are_all_axes_collective)
5265-
self.assertTrue(mesh_lib.abstract_mesh_context.mesh._are_all_axes_collective)
5265+
self.assertTrue(mesh_lib.get_abstract_mesh()._are_all_axes_collective)
52665266
allgatherd_y = jax.lax.all_gather(y, axis_name='x', axis=1, tiled=True)
52675267
z = x @ allgatherd_y
52685268
return jax.lax.psum(z, axis_name='y')

0 commit comments

Comments
 (0)