@@ -704,6 +704,7 @@ def forward(
             logits_soft_cap=self.logits_soft_cap,
             block_table=attn_metadata.block_table,
             common_prefix_len=attn_metadata.common_prefix_len,
+            max_num_splits=attn_metadata.max_num_splits,
             fa_version=self.vllm_flash_attn_version,
             prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
             suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
@@ -950,6 +951,7 @@ def cascade_attention(
     logits_soft_cap: float,
     block_table: torch.Tensor,
     common_prefix_len: int,
+    max_num_splits: int,
     fa_version: int,
     prefix_scheduler_metadata: torch.Tensor | None = None,
     suffix_scheduler_metadata: torch.Tensor | None = None,
@@ -994,7 +996,7 @@ def cascade_attention(
         # s_aux is incorporated into prefix_lse inside the GPU kernel,
         # enabling its effect during the final attention merge.
         s_aux=s_aux,
-        num_splits=1 if vllm_is_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
     )

     descale_shape = (cu_query_lens.shape[0] - 1, key_cache.shape[-2])
@@ -1019,7 +1021,7 @@ def cascade_attention(
         q_descale=q_descale.expand(descale_shape) if q_descale is not None else None,
         k_descale=k_descale.expand(descale_shape) if k_descale is not None else None,
         v_descale=v_descale.expand(descale_shape) if v_descale is not None else None,
-        num_splits=1 if vllm_is_batch_invariant() else 0,
+        num_splits=1 if vllm_is_batch_invariant() else max_num_splits,
     )

     # Merge prefix and suffix outputs, and store the result in output.
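For context on the `num_splits` change: in the vllm-flash-attn varlen kernels, `num_splits=1` disables split-KV entirely, `num_splits=0` defers to the kernel's internal heuristic, and a positive value bounds the number of KV splits. Below is a minimal sketch of the selection logic this diff threads through (the helper name is hypothetical, not part of vLLM):

def pick_num_splits(batch_invariant: bool, max_num_splits: int) -> int:
    """Sketch only, not vLLM's implementation: choose the num_splits
    argument passed to the flash-attention varlen kernel."""
    if batch_invariant:
        # Split-KV reduces partial results in an order that can vary with
        # the split count, so batch-invariant (deterministic) mode pins the
        # kernel to a single split.
        return 1
    # Previously this branch returned 0 (kernel heuristic); the diff above
    # threads a bound from the attention metadata through instead.
    return max_num_splits

One plausible motivation for a metadata-provided bound rather than the `0` heuristic is keeping the kernel's launch configuration stable across invocations (e.g., under CUDA-graph capture), though the diff itself only shows the plumbing.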