Skip to content

Commit 1e22149

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Fix the breakage caused by deleted enable_memories config
PiperOrigin-RevId: 707331603
1 parent cca9afa commit 1e22149

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

jax/_src/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ def trace_context():
215215
use_direct_linearize.value,
216216
softmax_custom_jvp.value,
217217
disable_jit.value,
218+
enable_memories.value,
218219
debug_key_reuse.value,
219220
jax_xla_profile_version.value,
220221
# Technically this affects jaxpr->stablehlo lowering, not tracing.
@@ -971,10 +972,20 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
971972
upgrade=True,
972973
help='If True, pmap and shard_map API will be merged.')
973974

975+
def _update_jax_memories_global(val):
976+
if hasattr(jax_jit.global_state(), 'enable_memories'):
977+
jax_jit.global_state().enable_memories = val
978+
979+
def _update_jax_memories_thread_local(val):
980+
if hasattr(jax_jit.thread_local_state(), 'enable_memories'):
981+
jax_jit.thread_local_state().enable_memories = val
982+
974983
enable_memories = bool_state(
975984
'jax_enable_memories',
976985
default=True,
977986
upgrade=True,
987+
update_global_hook=_update_jax_memories_global,
988+
update_thread_local_hook=_update_jax_memories_thread_local,
978989
help=("If True, will allow fetching memory kinds available on executable "
979990
"and annotate Shardings with it."))
980991

0 commit comments

Comments
 (0)