@@ -454,18 +454,6 @@ def local_devices(self):
454454 def local_mesh (self ):
455455 _raise_value_error ("local_mesh" )
456456
457- def __enter__ (self ):
458- abstract_mesh_context .stack .append (self )
459- abstract_mesh_context .mesh = self
460- jax_config .abstract_mesh_context_manager .set_local (abstract_mesh_context .mesh )
461- return self
462-
463- def __exit__ (self , exc_type , exc_value , traceback ):
464- abstract_mesh_context .stack .pop ()
465- abstract_mesh_context .mesh = abstract_mesh_context .stack [- 1 ]
466- jax_config .abstract_mesh_context_manager .set_local (abstract_mesh_context .mesh )
467- return False
468-
469457 @staticmethod
470458 def _extremely_unsafe_enter_tracing_context (mesh : AbstractMesh ):
471459 jax_config .abstract_mesh_context_manager .set_local (mesh )
@@ -478,37 +466,32 @@ def _raise_value_error(name):
478466 raise ValueError (f"AbstractMesh does not implement { name } " )
479467
480468
481- class AbstractMeshContext (threading .local ):
482- def __init__ (self ):
483- self .stack = [None ]
484- self .mesh = self .stack [- 1 ]
469+ @contextlib .contextmanager
470+ def set_abstract_mesh (mesh : AbstractMesh ):
471+ prev_val = jax_config .abstract_mesh_context_manager .swap_local (mesh )
472+ try :
473+ yield
474+ finally :
475+ jax_config .abstract_mesh_context_manager .set_local (prev_val )
485476
486- abstract_mesh_context = AbstractMeshContext ()
477+ def get_abstract_mesh ():
478+ return jax_config .abstract_mesh_context_manager .value
487479
488480
489481@contextlib .contextmanager
490- def set_mesh (mesh : Mesh ):
491- with ( mesh . abstract_mesh , jax_config .sharding_in_types ( True ),
492- enter_device_context ( mesh )) :
482+ def set_concrete_mesh (mesh : Mesh ):
483+ prev_val = jax_config .device_context . swap_local ( mesh )
484+ try :
493485 yield
486+ finally :
487+ jax_config .device_context .set_local (prev_val )
494488
495-
496- class DeviceContext (threading .local ):
497- def __init__ (self ):
498- self .stack = [None ]
499- self .concrete_mesh = self .stack [- 1 ]
500-
501- device_context = DeviceContext ()
489+ def get_concrete_mesh ():
490+ return jax_config .device_context .value
502491
503492
504493@contextlib .contextmanager
505- def enter_device_context (mesh : Mesh ):
506- device_context .stack .append (mesh )
507- device_context .concrete_mesh = mesh
508- jax_config .device_context .set_local (device_context .concrete_mesh )
509- try :
494+ def set_mesh (mesh : Mesh ):
495+ with (set_abstract_mesh (mesh .abstract_mesh ),
496+ jax_config .sharding_in_types (True ), set_concrete_mesh (mesh )):
510497 yield
511- finally :
512- device_context .stack .pop ()
513- device_context .concrete_mesh = device_context .stack [- 1 ]
514- jax_config .device_context .set_local (device_context .concrete_mesh )
0 commit comments