Skip to content

Commit eba0564

Browse files
pschuhjax authors
authored andcommitted
Allow disabling compilation cache for particular runtime_types.
PiperOrigin-RevId: 640264856
1 parent 8f090b3 commit eba0564

File tree

3 files changed

+16
-10
lines changed

3 files changed

+16
-10
lines changed

jax/_src/compilation_cache.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
# Mutex to protect _cache_initialized and _cache_used.
4848
_cache_initialized_mutex = threading.Lock()
4949

50+
_UNSUPPORTED_RUNTIMES: set[str] = set()
5051

5152
def set_once_cache_used(f) -> None:
5253
"""One-time setting of _cache_used.
@@ -134,10 +135,13 @@ def _initialize_cache() -> None:
134135
logger.debug("Initialized persistent compilation cache at %s", path)
135136

136137

137-
def _get_cache() -> CacheInterface | None:
138+
def _get_cache(backend) -> CacheInterface | None:
138139
# TODO(b/289098047): consider making this an API and changing the callers of
139140
# get_executable_and_time() and put_executable_and_time() to call get_cache()
140141
# and passing the result to them.
142+
if backend.runtime_type in _UNSUPPORTED_RUNTIMES:
143+
logger.debug("_get_cache: Unsupported runtime: %s", backend.runtime_type)
144+
return None
141145
if _cache is None:
142146
_initialize_cache() # initialization is done at most once; see above
143147
return _cache
@@ -158,9 +162,9 @@ def decompress_executable(executable):
158162
return zlib.decompress(executable)
159163

160164

161-
def is_executable_in_cache(cache_key: str) -> bool:
165+
def is_executable_in_cache(backend, cache_key: str) -> bool:
162166
"""Checks if the executable is in the cache."""
163-
cache = _get_cache()
167+
cache = _get_cache(backend)
164168
if cache is None:
165169
return False
166170

@@ -175,7 +179,7 @@ def get_executable_and_time(
175179
"""Returns the cached executable and its compilation time if present, or None
176180
otherwise.
177181
"""
178-
cache = _get_cache()
182+
cache = _get_cache(backend)
179183
if cache is None:
180184
logger.debug("get_executable_and_time: cache is disabled/not initialized")
181185
return None, None
@@ -201,7 +205,7 @@ def put_executable_and_time(
201205
"""Adds the 'executable' and its compilation time to the cache, possibly
202206
evicting older entries.
203207
"""
204-
cache = _get_cache()
208+
cache = _get_cache(backend)
205209
if cache is None:
206210
logger.debug("put_executable_and_time: cache is disabled/not initialized")
207211
return

jax/_src/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def compile_or_get_cached(
308308
computation, devices, compile_options, backend)
309309
compile_options.executable_build_options.fdo_profile = fdo_profile
310310

311-
if _is_executable_in_cache(pgle_profiled_module_key):
311+
if _is_executable_in_cache(backend, pgle_profiled_module_key):
312312
# Load PGLE profiled module from the persistent cache.
313313
cache_key = pgle_profiled_module_key
314314
if pgle_profiler is not None:
@@ -614,11 +614,11 @@ def _compile_and_write_cache(
614614
)
615615
return executable
616616

617-
def _is_executable_in_cache(cache_key) -> bool:
617+
def _is_executable_in_cache(backend, cache_key) -> bool:
618618
"""Checks if executable is presented in cache on a given key
619619
"""
620620
try:
621-
return compilation_cache.is_executable_in_cache(cache_key)
621+
return compilation_cache.is_executable_in_cache(backend, cache_key)
622622
except Exception as ex:
623623
if config.raise_persistent_cache_errors.value:
624624
raise

tests/compilation_cache_test.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,10 @@ def f(x):
231231
def test_cache_write_warning(self):
232232
f = jit(lambda x: x * x)
233233

234+
backend = xla_bridge.get_backend()
234235
with (
235236
config.raise_persistent_cache_errors(False),
236-
mock.patch.object(cc._get_cache().__class__, "put") as mock_put,
237+
mock.patch.object(cc._get_cache(backend).__class__, "put") as mock_put,
237238
warnings.catch_warnings(record=True) as w,
238239
):
239240
mock_put.side_effect = RuntimeError("test error")
@@ -252,9 +253,10 @@ def test_cache_write_warning(self):
252253
def test_cache_read_warning(self):
253254
f = jit(lambda x: x * x)
254255

256+
backend = xla_bridge.get_backend()
255257
with (
256258
config.raise_persistent_cache_errors(False),
257-
mock.patch.object(cc._get_cache().__class__, "get") as mock_get,
259+
mock.patch.object(cc._get_cache(backend).__class__, "get") as mock_get,
258260
warnings.catch_warnings(record=True) as w,
259261
):
260262
mock_get.side_effect = RuntimeError("test error")

0 commit comments

Comments
 (0)