@@ -276,11 +276,15 @@ def _check_shape(
276276
277277
278278def _prepare_for_flash_attn_or_sage_varlen (
279- batch_size : int , seq_len_q : int , attn_mask : Optional [torch .Tensor ] = None , device : Optional [torch .device ] = None
279+ batch_size : int ,
280+ seq_len_q : int ,
281+ seq_len_kv : int ,
282+ attn_mask : Optional [torch .Tensor ] = None ,
283+ device : Optional [torch .device ] = None ,
280284) -> None :
281285 seqlens_q = torch .full ((batch_size ,), seq_len_q , dtype = torch .int32 , device = device )
282286 if attn_mask is None :
283- seqlens_k = torch .full ((batch_size ,), seq_len_q , dtype = torch .int32 , device = device )
287+ seqlens_k = torch .full ((batch_size ,), seq_len_kv , dtype = torch .int32 , device = device )
284288 else :
285289 seqlens_k = attn_mask .sum (dim = 1 , dtype = torch .int32 )
286290 cu_seqlens_q = torch .zeros (batch_size + 1 , dtype = torch .int32 , device = device )
@@ -440,7 +444,9 @@ def _flash_varlen_attention(
440444
441445 if any (x is None for x in (cu_seqlens_q , cu_seqlens_k , max_seqlen_q , max_seqlen_k )):
442446 (_ , seqlens_k ), (cu_seqlens_q , cu_seqlens_k ), (max_seqlen_q , max_seqlen_k ) = (
443- _prepare_for_flash_attn_or_sage_varlen (batch_size , seq_len_q , attn_mask = attn_mask , device = query .device )
447+ _prepare_for_flash_attn_or_sage_varlen (
448+ batch_size , seq_len_q , seq_len_kv , attn_mask = attn_mask , device = query .device
449+ )
444450 )
445451 else :
446452 seqlens_k = torch .full ((batch_size ,), max_seqlen_k , dtype = torch .int32 , device = query .device )
@@ -730,7 +736,9 @@ def _sage_varlen_attention(
730736
731737 if any (x is None for x in (cu_seqlens_q , cu_seqlens_k , max_seqlen_q , max_seqlen_k )):
732738 (_ , seqlens_k ), (cu_seqlens_q , cu_seqlens_k ), (max_seqlen_q , max_seqlen_k ) = (
733- _prepare_for_flash_attn_or_sage_varlen (batch_size , seq_len_q , attn_mask = attn_mask , device = query .device )
739+ _prepare_for_flash_attn_or_sage_varlen (
740+ batch_size , seq_len_q , seq_len_kv , attn_mask = attn_mask , device = query .device
741+ )
734742 )
735743 else :
736744 seqlens_k = torch .full ((batch_size ,), max_seqlen_k , dtype = torch .int32 , device = query .device )
0 commit comments