Skip to content

Commit ba6b1fd

Browse files
committed
address lint and typecheck
1 parent bf4c3a7 commit ba6b1fd

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

jax/_src/cudnn/fused_attention_stablehlo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def check_is_flash_attention(
369369
)
370370

371371
if is_packed and cudnn_version < 90600:
372-
raise NotImplementedError(f"Packed layout requires cudnn version >= 9.6.")
372+
raise NotImplementedError("Packed layout requires cudnn version >= 9.6.")
373373

374374
def check_cudnn_version():
375375
# check if cuDNN is installed

tests/fused_attention_stablehlo_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,13 @@ def test_sdpa_large_head_size(self):
616616

617617
@jtu.run_on_devices("cuda")
618618
def test_sdpa_packed_layout(self):
619+
if jax.device_count() < 4:
620+
self.skipTest("Requires more than 4 devices.")
621+
try:
622+
cudnn_version = check_cudnn_version()
623+
except RuntimeError as e:
624+
self.skipTest(str(e))
625+
return
619626
if cudnn_version < 90600:
620627
self.skipTest("Requires >= cuDNN 9.6.0")
621628
k1, k2, k3, k4 = jax.random.split(jax.random.key(0), 4)

0 commit comments

Comments
 (0)