Skip to content

Commit cca9afa

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Delete enable_memories code in C++ since that flag is always True and cannot be turned off now.
PiperOrigin-RevId: 707298305
1 parent cce4066 commit cca9afa

File tree

1 file changed

+0
-9
lines changed

1 file changed

+0
-9
lines changed

jax/_src/config.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,6 @@ def trace_context():
214214
sharding_in_types.value,
215215
use_direct_linearize.value,
216216
softmax_custom_jvp.value,
217-
enable_memories.value,
218217
disable_jit.value,
219218
debug_key_reuse.value,
220219
jax_xla_profile_version.value,
@@ -972,18 +971,10 @@ def enum_flag(name, default, *args, **kwargs) -> Flag[str]:
972971
upgrade=True,
973972
help='If True, pmap and shard_map API will be merged.')
974973

975-
def _update_jax_memories_global(val):
976-
jax_jit.global_state().enable_memories = val
977-
978-
def _update_jax_memories_thread_local(val):
979-
jax_jit.thread_local_state().enable_memories = val
980-
981974
enable_memories = bool_state(
982975
'jax_enable_memories',
983976
default=True,
984977
upgrade=True,
985-
update_global_hook=_update_jax_memories_global,
986-
update_thread_local_hook=_update_jax_memories_thread_local,
987978
help=("If True, will allow fetching memory kinds available on executable "
988979
"and annotate Shardings with it."))
989980

0 commit comments

Comments
 (0)