Skip to content

Commit c35f8b2

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Add abstract mesh context manager to trace_context in the fallback path too (which will be deleted after jax 0.4.36 release)
PiperOrigin-RevId: 700006186
1 parent aa05dc0 commit c35f8b2

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

jax/_src/config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@ def trace_context():
244244
tls = jax_jit.thread_local_state()
245245
axis_env_state = ()
246246
mesh_context_manager = ()
247+
abstract_mesh_context_manager = ()
247248
xla_metadata_context_manager = ()
248249
compute_on_context_manager = ()
249250

@@ -252,11 +253,14 @@ def trace_context():
252253
axis_env_state = context.axis_env_state
253254
if context and context.mesh_context_manager:
254255
mesh_context_manager = context.mesh_context_manager
256+
if context and context.abstract_mesh_context_manager:
257+
abstract_mesh_context_manager = context.abstract_mesh_context_manager
255258
if context and context.xla_metadata_context_manager:
256259
xla_metadata_context_manager = context.xla_metadata_context_manager
257260
if context and context.compute_on_context_manager:
258261
compute_on_context_manager = context.compute_on_context_manager
259-
return (axis_env_state, mesh_context_manager, xla_metadata_context_manager,
262+
return (axis_env_state, mesh_context_manager, abstract_mesh_context_manager,
263+
xla_metadata_context_manager,
260264
compute_on_context_manager, enable_x64.value,
261265
numpy_rank_promotion.value, default_matmul_precision.value,
262266
dynamic_shapes.value,
@@ -1014,6 +1018,7 @@ class _ThreadLocalExtraJitContext(NamedTuple):
10141018
trace_state: Any | None = None
10151019
axis_env_state: Hashable = ()
10161020
mesh_context_manager: Hashable = ()
1021+
abstract_mesh_context_manager: Hashable = ()
10171022
compute_on_context_manager: Hashable = ()
10181023
xla_metadata_context_manager: Hashable = ()
10191024

@@ -1080,6 +1085,7 @@ def set_local(self, value):
10801085
trace_state = JitConfig('trace_state')
10811086
axis_env_state = JitConfig('axis_env_state')
10821087
mesh_context_manager = JitConfig('mesh_context_manager')
1088+
abstract_mesh_context_manager = JitConfig('abstract_mesh_context_manager')
10831089
compute_on_context_manager = JitConfig('compute_on_context_manager')
10841090
xla_metadata_context_manager = JitConfig('xla_metadata_context_manager')
10851091

0 commit comments

Comments
 (0)