Skip to content

Commit 8d7f665

Browse files
yashk2810Google-ML-Automation
authored andcommitted
Delete enable_memories code in C++ since that flag is always True and cannot be turned off now.
PiperOrigin-RevId: 707298305
1 parent e68461f commit 8d7f665

File tree

9 files changed

+2
-42
lines changed

9 files changed

+2
-42
lines changed

xla/python/jax_jit.cc

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,14 +138,6 @@ static std::string OptionalDebugString(
138138
}
139139
}
140140

141-
bool FetchMemoriesFlag() {
142-
auto& global_state = GlobalJitState();
143-
auto& thread_local_state = ThreadLocalJitState();
144-
CHECK(global_state.enable_memories.has_value());
145-
return thread_local_state.enable_memories.value_or(
146-
*global_state.enable_memories);
147-
}
148-
149141
std::string ArgumentSignature::DebugString() const {
150142
auto py_object_formatter = [](std::string* out, const nb::object& o) {
151143
out->append(nb::cast<absl::string_view>(nb::str(o)));
@@ -224,7 +216,6 @@ std::string CallSignature::DebugString() const {
224216
"device: %s\n"
225217
"default_device: %s\n"
226218
"jax_enable_x64: %d\n"
227-
"jax_enable_memories: %d\n"
228219
"global_extra_jit_context: %s\n"
229220
"thread_local_extra_jit_context: %s\n"
230221
"configs: %s\n",
@@ -234,7 +225,7 @@ std::string CallSignature::DebugString() const {
234225
absl::StrJoin(dynamic_arg_layouts, ", ", layout_formatter),
235226
absl::StrJoin(committed_args, ",", bool_formatter),
236227
device != nullptr ? device->DebugString() : "nullptr",
237-
OptionalDebugString(default_device), jax_enable_x64, jax_enable_memories,
228+
OptionalDebugString(default_device), jax_enable_x64,
238229
OptionalDebugString(global_extra_jit_context),
239230
OptionalDebugString(thread_local_extra_jit_context),
240231
absl::StrJoin(configs, ", ", py_object_formatter));
@@ -253,9 +244,6 @@ bool CallSignature::operator==(const CallSignature& other) const {
253244
if (jax_enable_x64 != other.jax_enable_x64) {
254245
return false;
255246
}
256-
if (jax_enable_memories != other.jax_enable_memories) {
257-
return false;
258-
}
259247
if (committed_args != other.committed_args) {
260248
return false;
261249
}
@@ -387,16 +375,12 @@ void BuildJaxjitSubmodule(nb::module_& m) {
387375
nb::class_<JitState> jit_state_(jitlib, "JitState");
388376
jit_state_.def_rw("disable_jit", &JitState::disable_jit, nb::arg().none());
389377
jit_state_.def_rw("enable_x64", &JitState::enable_x64, nb::arg().none());
390-
jit_state_.def_rw("enable_memories", &JitState::enable_memories,
391-
nb::arg().none());
392378
jit_state_.def_rw("default_device", &JitState::default_device,
393379
nb::arg().none());
394380
jit_state_.def_rw("extra_jit_context", &JitState::extra_jit_context,
395381
nb::arg().none());
396382
jit_state_.def_rw("post_hook", &JitState::post_hook, nb::arg().none());
397383

398-
GetEnableMemories = +[] { return FetchMemoriesFlag(); };
399-
400384
jitlib.def(
401385
"global_state", [&]() { return &GlobalJitState(); },
402386
nb::rv_policy::reference);

xla/python/jax_jit.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ struct JitState {
5959

6060
std::optional<bool> disable_jit;
6161
std::optional<bool> enable_x64;
62-
std::optional<bool> enable_memories;
6362

6463
// Used to manually set the default device jax should use. May be unset even
6564
// in global state, indicating there is no manual override.
@@ -205,7 +204,6 @@ struct CallSignature {
205204
// This is not the case for PMAP, and is set to `nullptr`.
206205
xla::PjRtDevice* device = nullptr;
207206
bool jax_enable_x64;
208-
bool jax_enable_memories = false;
209207

210208
// For JIT on PJIT, we need to fallback to python whenever default_device
211209
// changes.

xla/python/pjit.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,6 @@ absl::Status PjitFunction::ComputeCallSignature(
804804

805805
signature.default_device = GetDefaultDevice();
806806
signature.jax_enable_x64 = jax_enable_x64;
807-
signature.jax_enable_memories = GetEnableMemories();
808807

809808
auto& dynamic_arg_signatures = signature.dynamic_arg_signatures;
810809
dynamic_arg_signatures.reserve(flat_dynamic_args.size());

xla/python/py_array.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ PyArray PyArray::MakeFromSingleDeviceArray(
500500
auto dtype = IfrtDtypeToDtypeWithTokenCanonicalization(key.dtype).value();
501501
const ifrt::MemoryKind memory_kind = ifrt_array->sharding().memory_kind();
502502
nb::object py_memory_kind =
503-
(jax::GetEnableMemories() && memory_kind.memory_kind().has_value())
503+
(memory_kind.memory_kind().has_value())
504504
? nb::object(nb::str(memory_kind.memory_kind()->data(),
505505
memory_kind.memory_kind()->size()))
506506
: nb::none();

xla/python/py_device_list.cc

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,6 @@ void PyDeviceList::PopulateMemoryKindInfoForDuckTypedDevices() {
396396
}
397397

398398
absl::StatusOr<nb::tuple> PyDeviceList::MemoryKinds() {
399-
if (!GetEnableMemories()) {
400-
return nb::tuple();
401-
}
402399
if (!memory_kind_info_.has_value()) {
403400
PopulateMemoryKindInfo();
404401
}
@@ -409,9 +406,6 @@ absl::StatusOr<nb::tuple> PyDeviceList::MemoryKinds() {
409406
}
410407

411408
absl::StatusOr<nb::object> PyDeviceList::DefaultMemoryKind() {
412-
if (!GetEnableMemories()) {
413-
return nb::none();
414-
}
415409
if (!memory_kind_info_.has_value()) {
416410
PopulateMemoryKindInfo();
417411
}

xla/python/sharding.cc

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,6 @@ namespace jax {
4646

4747
namespace nb = nanobind;
4848

49-
bool (*GetEnableMemories)() = +[] {
50-
static bool fetch_memory_kind_on_executable = [] {
51-
char* v = getenv("JAX_ENABLE_MEMORIES");
52-
if (v == nullptr || *v == '\0') {
53-
return false;
54-
}
55-
return true;
56-
}();
57-
return fetch_memory_kind_on_executable;
58-
};
59-
6049
nb::object CheckAndCanonicalizeMemoryKind(
6150
nb::object memory_kind,
6251
const xla::nb_class_ptr<PyDeviceList>& device_list) {

xla/python/sharding.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,6 @@ class Sharding {
5252
std::optional<int> num_devices_;
5353
};
5454

55-
extern bool (*GetEnableMemories)();
56-
5755
// Checks if the memory kind is valid, and canonicalizes the
5856
// memory kind to default memory on backends that support memories.
5957
nanobind::object CheckAndCanonicalizeMemoryKind(

xla/python/xla_client_test.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
xla_client._xla.jax_jit.set_thread_local_state_initialization_callback(
5353
lambda: None
5454
)
55-
xla_client._xla.jax_jit.global_state().enable_memories = False
5655

5756
bfloat16 = xla_client.bfloat16
5857
# TODO(reedwm): Uncomment once the minimum ml_dtypes in JAX is >= 0.5.0.

xla/python/xla_extension/jax_jit.pyi

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ Device = xla_extension.Device
2727
class JitState:
2828
disable_jit: Optional[bool]
2929
enable_x64: Optional[bool]
30-
enable_memories: Optional[bool]
3130
default_device: Optional[Any]
3231
extra_jit_context: Optional[Any]
3332
post_hook: Optional[Callable[..., Any]]

0 commit comments

Comments
 (0)