@@ -314,13 +314,13 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.cumsum(seqlens_q, dim=0, dtype=torch.int32)
+    cu_seqlens_k = torch.cumsum(seqlens_k, dim=0, dtype=torch.int32)
+    cu_seqlens_q = torch.nn.functional.pad(cu_seqlens_q, (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(cu_seqlens_k, (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 
 
 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
@@ -331,13 +331,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 
 
 def _prepare_for_flash_attn_or_sage_varlen(
@@ -496,30 +494,18 @@ def _flash_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
-    else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
 
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = flash_attn_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -601,30 +587,18 @@ def _flash_varlen_attention_3(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
-    else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
 
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -958,30 +932,18 @@ def _sage_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
-    else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = sageattn_varlen(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,