
Commit db56a59

[BugFix] Fix FA3 IMA with FULL_AND_PIECEWISE and cascade attention (default) (vllm-project#28702)
Parent commit: 9324e10
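
Why this fixes the IMA: cascade_attention() previously hard-coded num_splits=0 in both of its flash-attention calls, which tells the FA3 kernel to pick its own split count. The attention metadata already carries a max_num_splits cap, presumably sized to match the split-KV workspace captured for CUDA graphs; under the default FULL_AND_PIECEWISE cudagraph mode with cascade attention enabled, an uncapped split count at replay could exceed that workspace and trigger the illegal memory access (IMA). The change below threads the metadata cap through cascade_attention() instead of hard-coding 0.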

2 files changed (+5, -2 lines)

tests/kernels/attention/test_cascade_flash_attn.py (1 addition, 0 deletions)

@@ -170,6 +170,7 @@ def test_cascade(
         logits_soft_cap=soft_cap if soft_cap is not None else 0,
         block_table=block_tables,
         common_prefix_len=common_prefix_len,
+        max_num_splits=0,  # no max
         fa_version=fa_version,
     )
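
Passing max_num_splits=0 here disables the cap, so the test exercises the same kernel behavior as before the parameter existed: with max_num_splits=0, the updated num_splits expression in cascade_attention() still evaluates to 0 (kernel-chosen splits) in the non-batch-invariant case.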

vllm/v1/attention/backends/flash_attn.py (4 additions, 2 deletions)

@@ -704,6 +704,7 @@ def forward(
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
+            max_num_splits=attn_metadata.max_num_splits,
             fa_version=self.vllm_flash_attn_version,
             prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
             suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@@ -950,6 +951,7 @@ def cascade_attention(
     logits_soft_cap: float,
     block_table: torch.Tensor,
     common_prefix_len: int,
+    max_num_splits: int,
     fa_version: int,
     prefix_scheduler_metadata: torch.Tensor | None = None,
     suffix_scheduler_metadata: torch.Tensor | None = None,
@@ -994,7 +996,7 @@
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_is_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
     )

     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -1019,7 +1021,7 @@
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_is_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
     )

     # Merge prefix and suffix outputs, and store the result in output.
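
For reference, this is the split-count policy the two changed call sites now share, extracted into a standalone sketch (pick_num_splits is a hypothetical helper for illustration, not part of vLLM):

def pick_num_splits(batch_invariant: bool, max_num_splits: int) -> int:
    # Batch-invariant mode forces a single split so the reduction order
    # is deterministic regardless of batch composition.
    if batch_invariant:
        return 1
    # Otherwise honor the cap carried in the attention metadata; 0 means
    # "no maximum" and lets the FA3 kernel choose freely (the value the
    # updated test passes).
    return max_num_splits

assert pick_num_splits(True, 8) == 1   # batch-invariant: always one split
assert pick_num_splits(False, 8) == 8  # cudagraph replay: capped at capture size
assert pick_num_splits(False, 0) == 0  # no cap: kernel decides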
