File tree Expand file tree Collapse file tree 2 files changed +9
-5
lines changed Expand file tree Collapse file tree 2 files changed +9
-5
lines changed Original file line number Diff line number Diff line change @@ -483,15 +483,19 @@ def __init__(self):
483483def push_mesh_context (val ):
484484 mesh_context .stack .append (val )
485485 mesh_context .mesh = val
486- jax_config .abstract_mesh_context_manager .set_local (
487- tuple (m for m in mesh_context .stack if m is not None ))
486+ # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
487+ # Right now that leads to weird numerical issues.
488+ non_none_meshes = tuple (m for m in mesh_context .stack if m is not None )
489+ if non_none_meshes :
490+ jax_config .abstract_mesh_context_manager .set_local (non_none_meshes )
488491 return val
489492
490493def pop_mesh_context ():
491494 mesh_context .stack .pop ()
492495 mesh_context .mesh = mesh_context .stack [- 1 ]
493- jax_config .abstract_mesh_context_manager .set_local (
494- tuple (m for m in mesh_context .stack if m is not None ))
496+ non_none_meshes = tuple (m for m in mesh_context .stack if m is not None )
497+ if non_none_meshes :
498+ jax_config .abstract_mesh_context_manager .set_local (non_none_meshes )
495499
496500
497501class null_mesh_context :
Original file line number Diff line number Diff line change @@ -709,7 +709,7 @@ def get_abstract_mesh(in_avals):
709709 # TODO(yashkatariya): Remove this when mesh context can be set by the user.
710710 if m is None :
711711 return mesh_lib .null_mesh_context ()
712- assert m is not None
712+ assert isinstance ( m , AbstractMesh )
713713 return m
714714
715715
You can’t perform that action at this time.
0 commit comments