3 changes: 2 additions & 1 deletion jax/_src/cudnn/fused_attention_stablehlo.py
@@ -380,7 +380,8 @@ def check_is_flash_attention(
   # bf16/fp16 attention conditions
   # Check the head dim.
   is_on_hopper = is_cuda_compute_capability_equal("9.0")
-  H_max = 256 if is_on_hopper else 128
+  is_on_blackwell = is_cuda_compute_capability_equal("10.0")
+  H_max = 256 if (is_on_hopper or is_on_blackwell) else 128
Contributor review comment (severity: medium):

To make this check more future-proof for upcoming GPU architectures, it would be better to check for a minimum compute capability rather than listing specific architectures. Assuming all architectures from Hopper (9.0) onwards will support H_max = 256, you can simplify this logic using check_compute_capability. This also aligns with how other capability checks are performed in this file (e.g., for is_packed and is_mla).

Suggested change:
-  is_on_hopper = is_cuda_compute_capability_equal("9.0")
-  H_max = 256 if is_on_hopper else 128
-  is_on_blackwell = is_cuda_compute_capability_equal("10.0")
-  H_max = 256 if (is_on_hopper or is_on_blackwell) else 128
+  is_hopper_or_later = check_compute_capability("9.0")
+  H_max = 256 if is_hopper_or_later else 128

# check if multi-head latent attention is needed
is_mla = qH != vH
if not (qH <= H_max and qH % 8 == 0):
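
For context on the reviewer's suggestion, here is a minimal sketch of why a minimum-capability check generalizes better than per-architecture equality checks. The min_compute_capability helper below is a hypothetical stand-in for this file's check_compute_capability, not its actual implementation:

import jax

def min_compute_capability(required: str) -> bool:
  # Hypothetical stand-in: True if the local GPU's compute capability is >= `required`.
  dev = jax.local_devices(backend="gpu")[0]
  current = tuple(int(x) for x in dev.compute_capability.split("."))
  target = tuple(int(x) for x in required.split("."))
  return current >= target

# Equality checks must enumerate every architecture that supports head dim 256:
#   256 if (capability == "9.0" or capability == "10.0" or ...) else 128
# A minimum-capability check covers Hopper, Blackwell, and later architectures in one test:
H_max = 256 if min_compute_capability("9.0") else 128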