@@ -455,10 +455,15 @@ def local_mesh(self):
455455 _raise_value_error ("local_mesh" )
456456
457457 def __enter__ (self ):
458- return push_abstract_mesh_context (self )
458+ abstract_mesh_context .stack .append (self )
459+ abstract_mesh_context .mesh = self
460+ jax_config .abstract_mesh_context_manager .set_local (abstract_mesh_context .mesh )
461+ return self
459462
460463 def __exit__ (self , exc_type , exc_value , traceback ):
461- pop_abstract_mesh_context ()
464+ abstract_mesh_context .stack .pop ()
465+ abstract_mesh_context .mesh = abstract_mesh_context .stack [- 1 ]
466+ jax_config .abstract_mesh_context_manager .set_local (abstract_mesh_context .mesh )
462467 return False
463468
464469 @staticmethod
@@ -480,35 +485,6 @@ def __init__(self):
480485
481486abstract_mesh_context = AbstractMeshContext ()
482487
483- def push_abstract_mesh_context (val ):
484- abstract_mesh_context .stack .append (val )
485- abstract_mesh_context .mesh = val
486- # TODO(yashkatariya): Allow setting empty tuples and tuples with None in them.
487- # Right now that leads to weird numerical issues.
488- non_none_meshes = tuple (m for m in abstract_mesh_context .stack
489- if m is not None )
490- if non_none_meshes :
491- jax_config .abstract_mesh_context_manager .set_local (non_none_meshes )
492- return val
493-
494- def pop_abstract_mesh_context ():
495- abstract_mesh_context .stack .pop ()
496- abstract_mesh_context .mesh = abstract_mesh_context .stack [- 1 ]
497- non_none_meshes = tuple (m for m in abstract_mesh_context .stack
498- if m is not None )
499- if non_none_meshes :
500- jax_config .abstract_mesh_context_manager .set_local (non_none_meshes )
501-
502-
503- class null_mesh_context :
504-
505- def __enter__ (self ):
506- return push_abstract_mesh_context (None )
507-
508- def __exit__ (self , * excinfo ):
509- pop_abstract_mesh_context ()
510- return False
511-
512488
513489@contextlib .contextmanager
514490def set_mesh (mesh : Mesh ):
@@ -529,14 +505,10 @@ def __init__(self):
529505def enter_device_context (mesh : Mesh ):
530506 device_context .stack .append (mesh )
531507 device_context .concrete_mesh = mesh
532- non_none_meshes = tuple (m for m in device_context .stack if m is not None )
533- if non_none_meshes :
534- jax_config .device_context .set_local (non_none_meshes )
508+ jax_config .device_context .set_local (device_context .concrete_mesh )
535509 try :
536510 yield
537511 finally :
538512 device_context .stack .pop ()
539513 device_context .concrete_mesh = device_context .stack [- 1 ]
540- non_none_meshes = tuple (m for m in device_context .stack if m is not None )
541- if non_none_meshes :
542- jax_config .device_context .set_local (non_none_meshes )
514+ jax_config .device_context .set_local (device_context .concrete_mesh )
0 commit comments