@@ -138,14 +138,6 @@ static std::string OptionalDebugString(
138138 }
139139}
140140
141- bool FetchMemoriesFlag () {
142- auto & global_state = GlobalJitState ();
143- auto & thread_local_state = ThreadLocalJitState ();
144- CHECK (global_state.enable_memories .has_value ());
145- return thread_local_state.enable_memories .value_or (
146- *global_state.enable_memories );
147- }
148-
149141std::string ArgumentSignature::DebugString () const {
150142 auto py_object_formatter = [](std::string* out, const nb::object& o) {
151143 out->append (nb::cast<absl::string_view>(nb::str (o)));
@@ -224,7 +216,6 @@ std::string CallSignature::DebugString() const {
224216 " device: %s\n "
225217 " default_device: %s\n "
226218 " jax_enable_x64: %d\n "
227- " jax_enable_memories: %d\n "
228219 " global_extra_jit_context: %s\n "
229220 " thread_local_extra_jit_context: %s\n "
230221 " configs: %s\n " ,
@@ -234,7 +225,7 @@ std::string CallSignature::DebugString() const {
234225 absl::StrJoin (dynamic_arg_layouts, " , " , layout_formatter),
235226 absl::StrJoin (committed_args, " ," , bool_formatter),
236227 device != nullptr ? device->DebugString () : " nullptr" ,
237- OptionalDebugString (default_device), jax_enable_x64, jax_enable_memories,
228+ OptionalDebugString (default_device), jax_enable_x64,
238229 OptionalDebugString (global_extra_jit_context),
239230 OptionalDebugString (thread_local_extra_jit_context),
240231 absl::StrJoin (configs, " , " , py_object_formatter));
@@ -253,9 +244,6 @@ bool CallSignature::operator==(const CallSignature& other) const {
253244 if (jax_enable_x64 != other.jax_enable_x64 ) {
254245 return false ;
255246 }
256- if (jax_enable_memories != other.jax_enable_memories ) {
257- return false ;
258- }
259247 if (committed_args != other.committed_args ) {
260248 return false ;
261249 }
@@ -387,16 +375,12 @@ void BuildJaxjitSubmodule(nb::module_& m) {
387375 nb::class_<JitState> jit_state_ (jitlib, " JitState" );
388376 jit_state_.def_rw (" disable_jit" , &JitState::disable_jit, nb::arg ().none ());
389377 jit_state_.def_rw (" enable_x64" , &JitState::enable_x64, nb::arg ().none ());
390- jit_state_.def_rw (" enable_memories" , &JitState::enable_memories,
391- nb::arg ().none ());
392378 jit_state_.def_rw (" default_device" , &JitState::default_device,
393379 nb::arg ().none ());
394380 jit_state_.def_rw (" extra_jit_context" , &JitState::extra_jit_context,
395381 nb::arg ().none ());
396382 jit_state_.def_rw (" post_hook" , &JitState::post_hook, nb::arg ().none ());
397383
398- GetEnableMemories = +[] { return FetchMemoriesFlag (); };
399-
400384 jitlib.def (
401385 " global_state" , [&]() { return &GlobalJitState (); },
402386 nb::rv_policy::reference);
0 commit comments