
Commit eda7506

apaszke authored and Google-ML-Automation committed
[Pallas MGPU] Disable XLA:GPU autotuning in attention tests
We don't care about the performance of the reference implementation; we only use it for correctness testing. More importantly, this works around a compile-time deadlock that sometimes occurs when testing large batch sizes.

PiperOrigin-RevId: 703521029
1 parent: 8b65620
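For reference, the same flag can be applied outside Bazel by exporting XLA_FLAGS before JAX initializes its GPU backend. A minimal sketch, assuming a standalone Python reproduction script (only the flag itself comes from this commit; everything else is illustrative):

import os
# Set before importing JAX so XLA:GPU reads the flag when the backend initializes.
os.environ["XLA_FLAGS"] = "--xla_gpu_autotune_level=0"
import jax.numpy as jnp
# Any subsequent GPU compilation now skips XLA:GPU autotuning.
print(jnp.dot(jnp.ones((8, 8)), jnp.ones((8, 8))))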

File tree: 1 file changed (+2, -0 lines)

tests/pallas/BUILD

Lines changed: 2 additions & 0 deletions
@@ -494,6 +494,7 @@ jax_multiplatform_test(
     srcs = ["//jax/experimental/pallas/ops/gpu:attention_mgpu.py"],
     enable_backends = [],
     enable_configs = ["gpu_h100_x32"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     tags = [
         "manual",
         "notap",
@@ -509,6 +510,7 @@ jax_multiplatform_test(
     srcs = ["mgpu_attention_test.py"],
     enable_backends = [],
     enable_configs = ["gpu_h100_x32"],
+    env = {"XLA_FLAGS": "--xla_gpu_autotune_level=0"},
     deps = [
         "//jax:pallas",
         "//jax:pallas_experimental_gpu_ops",
