Skip to content

Commit bf0150b

Browse files
[JAX] Ignore xla_gpu_experimental_autotune_cache_mode when calculating module hash.
PiperOrigin-RevId: 698789020
1 parent 7d7a0fa commit bf0150b

File tree

2 files changed

+6
-0
lines changed

2 files changed

+6
-0
lines changed

jax/_src/cache_key.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import cast as type_cast
2222

2323
from jax._src import config
24+
from jax._src.lib import version as jaxlib_version
2425
from jax._src.lib import version_str as jaxlib_version_str
2526
from jax._src.lib import xla_client
2627
from jax._src.lib.mlir import ir
@@ -225,6 +226,8 @@ def _hash_serialized_compile_options(hash_obj, compile_options_obj,
225226
debug_options.xla_dump_hlo_as_long_text = False
226227
debug_options.xla_dump_disable_metadata = False
227228
debug_options.xla_dump_hlo_pipeline_re = ""
229+
if jaxlib_version > (0, 4, 35):
230+
debug_options.xla_gpu_experimental_autotune_cache_mode = 0
228231
# Optional way to specify the cuda install path to be used by the compiler.
229232
# This could possibly affect the cuda version compiled with, but this should
230233
# already be included in the platform information (and might not be reflected

tests/cache_key_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from jax._src import test_util as jtu
3232
from jax._src import xla_bridge
3333
from jax._src.lib import xla_client
34+
from jax._src.lib import version as jaxlib_version
3435
from jax._src.lib.mlir import ir
3536
from jax._src.mesh import Mesh
3637
from jax._src.partition_spec import PartitionSpec as P
@@ -68,6 +69,8 @@ def test_serialized_compile_options(self):
6869
debug_options.xla_dump_hlo_as_long_text = True
6970
debug_options.xla_dump_disable_metadata = True
7071
debug_options.xla_dump_hlo_pipeline_re = "xyzzy"
72+
if jaxlib_version > (0, 4, 35):
73+
debug_options.xla_gpu_experimental_autotune_cache_mode = 2
7174
hash2 = self.get_hashed_value(
7275
cache_key._hash_serialized_compile_options, compile_options
7376
)

0 commit comments

Comments
 (0)