Skip to content

Commit 16a5607

Browse files
Use xla_extension_version instead of jaxlib_version
PiperOrigin-RevId: 700265297
1 parent 024e331 commit 16a5607

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

jax/_src/cache_key.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
from typing import cast as type_cast
2222

2323
from jax._src import config
24-
from jax._src.lib import version as jaxlib_version
2524
from jax._src.lib import version_str as jaxlib_version_str
2625
from jax._src.lib import xla_client
26+
from jax._src.lib import xla_extension_version
2727
from jax._src.lib.mlir import ir
2828
from jax._src.lib.mlir import passmanager as pm
2929
import numpy as np
@@ -226,8 +226,11 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
226226
debug_options.xla_dump_hlo_as_long_text = False
227227
debug_options.xla_dump_disable_metadata = False
228228
debug_options.xla_dump_hlo_pipeline_re = ""
229-
if jaxlib_version > (0, 4, 35):
229+
230+
# "Requires jaxlib 0.4.36+"
231+
if xla_extension_version > 296:
230232
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
233+
231234
# Optional way to specify the cuda install path to be used by the compiler.
232235
# This could possibly affect the cuda version compiled with, but this should
233236
# already be included in the platform information (and might not be reflected

0 commit comments

Comments
 (0)