@@ -244,6 +244,7 @@ def trace_context():
244244 tls = jax_jit .thread_local_state ()
245245 axis_env_state = ()
246246 mesh_context_manager = ()
247+ abstract_mesh_context_manager = ()
247248 xla_metadata_context_manager = ()
248249 compute_on_context_manager = ()
249250
@@ -252,11 +253,14 @@ def trace_context():
252253 axis_env_state = context .axis_env_state
253254 if context and context .mesh_context_manager :
254255 mesh_context_manager = context .mesh_context_manager
256+ if context and context .abstract_mesh_context_manager :
257+ abstract_mesh_context_manager = context .abstract_mesh_context_manager
255258 if context and context .xla_metadata_context_manager :
256259 xla_metadata_context_manager = context .xla_metadata_context_manager
257260 if context and context .compute_on_context_manager :
258261 compute_on_context_manager = context .compute_on_context_manager
259- return (axis_env_state , mesh_context_manager , xla_metadata_context_manager ,
262+ return (axis_env_state , mesh_context_manager , abstract_mesh_context_manager ,
263+ xla_metadata_context_manager ,
260264 compute_on_context_manager , enable_x64 .value ,
261265 numpy_rank_promotion .value , default_matmul_precision .value ,
262266 dynamic_shapes .value ,
@@ -1014,6 +1018,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
10141018 trace_state : Any | None = None
10151019 axis_env_state : Hashable = ()
10161020 mesh_context_manager : Hashable = ()
1021+ abstract_mesh_context_manager : Hashable = ()
10171022 compute_on_context_manager : Hashable = ()
10181023 xla_metadata_context_manager : Hashable = ()
10191024
@@ -1080,6 +1085,7 @@ def set_local(self, value):
10801085 trace_state = JitConfig ('trace_state' )
10811086 axis_env_state = JitConfig ('axis_env_state' )
10821087 mesh_context_manager = JitConfig ('mesh_context_manager' )
1088+ abstract_mesh_context_manager = JitConfig ('abstract_mesh_context_manager' )
10831089 compute_on_context_manager = JitConfig ('compute_on_context_manager' )
10841090 xla_metadata_context_manager = JitConfig ('xla_metadata_context_manager' )
10851091
0 commit comments