Skip to content

Commit 6763fcf

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix a weird interaction with set_local and empty tuples passed to it.
PiperOrigin-RevId: 700392735
1 parent e453fa1 commit 6763fcf

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

jax/_src/mesh.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,15 +483,19 @@ def __init__(self):
483483
def push_mesh_context(val):
484484
mesh_context.stack.append(val)
485485
mesh_context.mesh = val
486-
jax_config.abstract_mesh_context_manager.set_local(
487-
tuple(m for m in mesh_context.stack if m is not None))
486+
# TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
487+
# Right now that leads to weird numerical issues.
488+
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
489+
if non_none_meshes:
490+
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
488491
return val
489492

490493
def pop_mesh_context():
491494
mesh_context.stack.pop()
492495
mesh_context.mesh = mesh_context.stack[-1]
493-
jax_config.abstract_mesh_context_manager.set_local(
494-
tuple(m for m in mesh_context.stack if m is not None))
496+
non_none_meshes = tuple(m for m in mesh_context.stack if m is not None)
497+
if non_none_meshes:
498+
jax_config.abstract_mesh_context_manager.set_local(non_none_meshes)
495499

496500

497501
class null_mesh_context:

jax/_src/pjit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals):
709709
# TODO(yashkatariya): Remove this when mesh context can be set by the user.
710710
if m is None:
711711
return mesh_lib.null_mesh_context()
712-
assert m is not None
712+
assert isinstance(m, AbstractMesh)
713713
return m
714714

715715

0 commit comments

Comments
 (0)