Skip to content

Commit 627debc

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Create a null_mesh_context internal context manager to handle null contexts properly.
PiperOrigin-RevId: 700167406
1 parent 59e13f8 commit 627debc

File tree

3 files changed

+29
-14
lines changed

3 files changed

+29
-14
lines changed

jax/_src/mesh.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -455,17 +455,10 @@ def local_mesh(self):
455455
_raise_value_error("local_mesh")
456456

457457
def __enter__(self):
458-
mesh_context.stack.append(self)
459-
mesh_context.mesh = self
460-
jax_config.abstract_mesh_context_manager.set_local(
461-
tuple(m for m in mesh_context.stack if m is not None))
462-
return self
458+
return push_mesh_context(self)
463459

464460
def __exit__(self, exc_type, exc_value, traceback):
465-
mesh_context.stack.pop()
466-
mesh_context.mesh = mesh_context.stack[-1]
467-
jax_config.abstract_mesh_context_manager.set_local(
468-
tuple(m for m in mesh_context.stack if m is not None))
461+
pop_mesh_context()
469462
return False
470463

471464
@staticmethod
@@ -486,3 +479,26 @@ def __init__(self):
486479
self.mesh = self.stack[-1]
487480

488481
mesh_context = MeshContext()
482+
483+
def push_mesh_context(val):
484+
mesh_context.stack.append(val)
485+
mesh_context.mesh = val
486+
jax_config.abstract_mesh_context_manager.set_local(
487+
tuple(m for m in mesh_context.stack if m is not None))
488+
return val
489+
490+
def pop_mesh_context():
491+
mesh_context.stack.pop()
492+
mesh_context.mesh = mesh_context.stack[-1]
493+
jax_config.abstract_mesh_context_manager.set_local(
494+
tuple(m for m in mesh_context.stack if m is not None))
495+
496+
497+
class null_mesh_context:
498+
499+
def __enter__(self):
500+
return push_mesh_context(None)
501+
502+
def __exit__(self, *excinfo):
503+
pop_mesh_context()
504+
return False

jax/_src/pjit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from collections import defaultdict
1818
from collections.abc import Callable, Sequence, Iterable
19-
import contextlib
2019
import dataclasses
2120
from functools import partial
2221
import inspect
@@ -696,7 +695,7 @@ def _infer_params_impl(
696695

697696
def get_abstract_mesh(in_avals):
698697
if not config.sharding_in_types.value:
699-
return contextlib.nullcontext()
698+
return mesh_lib.null_mesh_context()
700699
m = None
701700
for a in in_avals:
702701
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
@@ -709,7 +708,7 @@ def get_abstract_mesh(in_avals):
709708
m = a.sharding.mesh # type: ignore
710709
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
711710
if m is None:
712-
return contextlib.nullcontext()
711+
return mesh_lib.null_mesh_context()
713712
assert m is not None
714713
return m
715714

jax/_src/stages.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
"""
3131
from __future__ import annotations
3232

33-
import contextlib
3433
import functools
3534
from collections.abc import Sequence
3635
from dataclasses import dataclass
@@ -44,6 +43,7 @@
4443
from jax._src import traceback_util
4544
from jax._src import tree_util
4645
from jax._src import util
46+
from jax._src import mesh as mesh_lib
4747
from jax._src.sharding_impls import UnspecifiedValue, AUTO
4848
from jax._src.layout import Layout
4949
from jax._src.interpreters import mlir
@@ -717,7 +717,7 @@ class Traced(Stage):
717717
"_args_flat", "_arg_names", "_num_consts"]
718718

719719
def __init__(self, jaxpr: core.ClosedJaxpr, args_info, fun_name, out_tree,
720-
lower_callable, abstract_mesh=contextlib.nullcontext(),
720+
lower_callable, abstract_mesh=mesh_lib.null_mesh_context(),
721721
args_flat=None, arg_names=None, num_consts: int = 0):
722722
self.jaxpr = jaxpr
723723
self.args_info = args_info

0 commit comments

Comments
 (0)