Skip to content

Commit 653f654

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix the broken behavior of not resetting the abstract_mesh and device_context properly during __exit__.
PiperOrigin-RevId: 702762477
1 parent 681b9c2 commit 653f654

File tree

3 files changed

+14
-41
lines changed

3 files changed

+14
-41
lines changed

jax/_src/mesh.py

Lines changed: 9 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -455,10 +455,15 @@ def local_mesh(self):
455455
_raise_value_error("local_mesh")
456456

457457
def __enter__(self):
458-
return push_abstract_mesh_context(self)
458+
abstract_mesh_context.stack.append(self)
459+
abstract_mesh_context.mesh = self
460+
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
461+
return self
459462

460463
def __exit__(self, exc_type, exc_value, traceback):
461-
pop_abstract_mesh_context()
464+
abstract_mesh_context.stack.pop()
465+
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
466+
jax_config.abstract_mesh_context_manager.set_local(abstract_mesh_context.mesh)
462467
return False
463468

464469
@staticmethod
@@ -480,35 +485,6 @@ def __init__(self):
480485

481486
abstract_mesh_context = AbstractMeshContext()
482487

483-
def push_abstract_mesh_context(val):
484-
abstract_mesh_context.stack.append(val)
485-
abstract_mesh_context.mesh = val
486-
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
487-
# Right now that leads to weird numerical issues.
488-
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
489-
if m is not None)
490-
if non_none_meshes:
491-
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
492-
return val
493-
494-
def pop_abstract_mesh_context():
495-
abstract_mesh_context.stack.pop()
496-
abstract_mesh_context.mesh = abstract_mesh_context.stack[-1]
497-
non_none_meshes = tuple(m for m in abstract_mesh_context.stack
498-
if m is not None)
499-
if non_none_meshes:
500-
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
501-
502-
503-
class null_mesh_context:
504-
505-
def __enter__(self):
506-
return push_abstract_mesh_context(None)
507-
508-
def __exit__(self, *excinfo):
509-
pop_abstract_mesh_context()
510-
return False
511-
512488

513489
@contextlib.contextmanager
514490
def set_mesh(mesh: Mesh):
@@ -529,14 +505,10 @@ def __init__(self):
529505
def enter_device_context(mesh: Mesh):
530506
device_context.stack.append(mesh)
531507
device_context.concrete_mesh = mesh
532-
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
533-
if non_none_meshes:
534-
jax_config.device_context.set_local(non_none_meshes)
508+
jax_config.device_context.set_local(device_context.concrete_mesh)
535509
try:
536510
yield
537511
finally:
538512
device_context.stack.pop()
539513
device_context.concrete_mesh = device_context.stack[-1]
540-
non_none_meshes = tuple(m for m in device_context.stack if m is not None)
541-
if non_none_meshes:
542-
jax_config.device_context.set_local(non_none_meshes)
514+
jax_config.device_context.set_local(device_context.concrete_mesh)

jax/_src/pjit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from collections import defaultdict
1818
from collections.abc import Callable, Sequence, Iterable
19+
import contextlib
1920
import dataclasses
2021
from functools import partial
2122
import inspect
@@ -695,7 +696,7 @@ def _infer_params_impl(
695696

696697
def get_abstract_mesh(in_avals):
697698
if not config.sharding_in_types.value:
698-
return mesh_lib.null_mesh_context()
699+
return contextlib.nullcontext()
699700
m = None
700701
for a in in_avals:
701702
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
@@ -708,7 +709,7 @@ def get_abstract_mesh(in_avals):
708709
m = a.sharding.mesh # type: ignore
709710
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
710711
if m is None:
711-
return mesh_lib.null_mesh_context()
712+
return contextlib.nullcontext()
712713
assert isinstance(m, AbstractMesh)
713714
return m
714715

jax/_src/stages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
"""
3131
from __future__ import annotations
3232

33+
import contextlib
3334
import functools
3435
from collections.abc import Sequence
3536
from dataclasses import dataclass
@@ -43,7 +44,6 @@
4344
from jax._src import traceback_util
4445
from jax._src import tree_util
4546
from jax._src import util
46-
from jax._src import mesh as mesh_lib
4747
from jax._src.sharding_impls import UnspecifiedValue, AUTO
4848
from jax._src.layout import Layout
4949
from jax._src.interpreters import mlir
@@ -717,7 +717,7 @@ class Traced(Stage):
717717
"_args_flat", "_arg_names", "_num_consts"]
718718

719719
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
720-
lower_callable, abstract_mesh=mesh_lib.null_mesh_context(),
720+
lower_callable, abstract_mesh=contextlib.nullcontext(),
721721
args_flat=None, arg_names=None, num_consts: int = 0):
722722
self.jaxpr = jaxpr
723723
self.args_info = args_info

0 commit comments

Comments
 (0)