4 changes: 2 additions & 2 deletions jax/_src/cudnn/fused_attention_stablehlo.py
@@ -379,8 +379,8 @@ def check_is_flash_attention(
   else:
     # bf16/fp16 attention conditions
     # Check the head dim.
-    is_on_hopper = is_cuda_compute_capability_equal("9.0")
-    H_max = 256 if is_on_hopper else 128
+    is_hopper_or_later = check_compute_capability("9.0")
+    H_max = 256 if is_hopper_or_later else 128
     # check if multi-head latent attention is needed
     is_mla = qH != vH
     if not (qH <= H_max and qH % 8 == 0):
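
For context on what the rename changes: is_cuda_compute_capability_equal("9.0") matched Hopper exactly, so anything newer than sm_90 would fall through to H_max = 128. The new name implies an "at least this capability" test, letting Hopper and later architectures share the 256 head-dim limit. Below is a minimal sketch of that semantics; check_compute_capability here is written from scratch for illustration and may not match the helper's actual implementation in JAX.

# Hypothetical sketch of ">= capability" semantics, not the actual JAX helper.
import jax

def check_compute_capability(capability: str) -> bool:
  # Compare (major, minor) tuples so e.g. "10.0" correctly ranks above "9.0".
  target = tuple(int(x) for x in capability.split("."))
  devices = jax.local_devices(backend="gpu")  # raises if no GPU backend
  current = tuple(int(x) for x in devices[0].compute_capability.split("."))
  return current >= target

# Effect on the head-dim limit: Hopper and newer get 256, older GPUs 128.
H_max = 256 if check_compute_capability("9.0") else 128

Under an equality check, the same snippet on a post-Hopper GPU would have chosen 128, which is the behavior this diff removes.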