@@ -305,6 +305,7 @@ def dispatch_attention_fn(
     *,
     backend: Optional[AttentionBackendName] = None,
     parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     attention_kwargs = attention_kwargs or {}

@@ -327,6 +328,8 @@ def dispatch_attention_fn(
         **attention_kwargs,
         "_parallel_config": parallel_config,
     }
+    if seq_lens is not None:
+        kwargs["seq_lens"] = seq_lens
     if is_torch_version(">=", "2.5.0"):
         kwargs["enable_gqa"] = enable_gqa

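A minimal sketch of how the new `seq_lens` argument might be threaded through the dispatcher, assuming the call site keeps the signature shown above; the backend enum member and tensor values here are illustrative assumptions, not part of this diff:

    # per-sample valid lengths for a padded batch of 2 sequences
    seq_lens = torch.tensor([3, 5], dtype=torch.int32)
    out = dispatch_attention_fn(
        query, key, value,
        backend=AttentionBackendName.FLASH_VARLEN,  # assumed member; any varlen-capable backend
        seq_lens=seq_lens,
    )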
@@ -1400,18 +1403,29 @@ def _flash_varlen_attention(
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        # use the same lengths for Q and KV
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
-    )

     key_valid, value_valid = [], []
     for b in range(batch_size):
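The cumulative-lengths construction in the new branch is easiest to see on concrete numbers; a standalone check using the same ops as the diff:

    import torch

    seq_lens = torch.tensor([3, 5])
    cu_seqlens = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
    print(cu_seqlens)           # tensor([0, 3, 8], dtype=torch.int32)
    print(int(seq_lens.max()))  # 5 -> max_seqlen_q == max_seqlen_k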
@@ -1521,18 +1535,28 @@ def _flash_varlen_attention_3(
     is_causal: bool = False,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
-    )

     key_valid, value_valid = [], []
     for b in range(batch_size):
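For context on what consumes these offsets: flash-attn's varlen kernels take packed `[total_tokens, heads, dim]` tensors rather than padded `[batch, seq, heads, dim]` ones, with `cu_seqlens` marking each sample's row range. A rough sketch of the packing step that pairs with the offsets above; the helper name and the exact downstream call are assumptions, not part of this diff:

    def pack_padded(x: torch.Tensor, seq_lens: torch.Tensor) -> torch.Tensor:
        # [batch, seq, heads, dim] -> [sum(seq_lens), heads, dim], dropping padding
        return torch.cat([x[b, : int(seq_lens[b])] for b in range(x.shape[0])], dim=0)

    q_packed = pack_padded(query, seq_lens)  # rows line up with cu_seqlens_q offsets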
@@ -2023,21 +2047,31 @@ def _sage_varlen_attention(
     scale: Optional[float] = None,
     return_lse: bool = False,
     _parallel_config: Optional["ParallelConfig"] = None,
+    seq_lens: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     if return_lse:
         raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")

     batch_size, seq_len_q, _, _ = query.shape
     _, seq_len_kv, _, _ = key.shape

-    if attn_mask is not None:
-        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+    if seq_lens is not None:
+        seq_lens = seq_lens.to(query.device)
+        seqlens_k = seq_lens
+        cu_seqlens_q = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
+        cu_seqlens_k = cu_seqlens_q
+        max_seqlen_q = int(seq_lens.max().item())
+        max_seqlen_k = max_seqlen_q
+        attn_mask = None  # varlen uses lengths
+    else:
+        if attn_mask is not None:
+            attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

-    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-        _prepare_for_flash_attn_or_sage_varlen(
-            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+            _prepare_for_flash_attn_or_sage_varlen(
+                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+            )
         )
-    )

     key_valid, value_valid = [], []
     for b in range(batch_size):
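The `seq_lens` fast path is duplicated verbatim across the three backends. The same derivation, factored into a standalone sketch (a hypothetical helper, not in the diff), may make the invariants clearer: equal Q/KV lengths, int32 cumulative offsets, and a disabled mask:

    def _varlen_from_seq_lens(seq_lens: torch.Tensor, device: torch.device):
        seq_lens = seq_lens.to(device)
        cu = torch.cat([seq_lens.new_zeros(1), seq_lens.cumsum(0)], dim=0).to(torch.int32)
        max_len = int(seq_lens.max().item())
        # (seqlens_k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)
        return seq_lens, cu, cu, max_len, max_len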