@@ -3,19 +3,16 @@
 import triton
 import triton.language as tl
 
-from flash_sparse_attn.ops.triton import seqlen_info, block_info, mask, softmax
+from flash_sparse_attn.ops.triton import utils, seqlen_info, block_info, mask, softmax
+
+
+fwd_base_autotune_configs = utils.get_fwd_base_autotune_configs(True)
 
 
 @triton.autotune(
-    configs=[
-        triton.Config({"TILE_M": 128, "TILE_N": 128}, num_warps=4, num_stages=1),
-        triton.Config({"TILE_M": 128, "TILE_N": 64}, num_warps=4, num_stages=1),
-        triton.Config({"TILE_M": 64, "TILE_N": 64}, num_warps=4, num_stages=1),
-        triton.Config({"TILE_M": 128, "TILE_N": 128}, num_warps=4, num_stages=2),
-        triton.Config({"TILE_M": 128, "TILE_N": 64}, num_warps=4, num_stages=2),
-        triton.Config({"TILE_M": 64, "TILE_N": 64}, num_warps=4, num_stages=2),
-    ],
-    key=["IS_CAUSAL", "IS_LOCAL", "TILE_K"],
+    configs=fwd_base_autotune_configs,
+    key=utils.FWD_BASE_AUTOTUNE_KEYS,
+    use_cuda_graph=True,
 )
 @triton.jit
 def _fwd_base_kernel(
@@ -216,6 +213,18 @@ def _fwd_base_kernel(
             )
         else:
             tl.store(lse_ptrs, lse_tile, boundary_check=(0,))
+
+        # We can't get the dtype of the query for the output here, so we initialize the output to zero
+        # # Write output as zero for proper handling
+        # if PACK_GQA:
+        #     tl.store(
+        #         out_ptrs,
+        #         o_tile,
+        #         mask=((offs_m // QHEADS_PER_KVHEAD_PACKGQA) < actual_seqlen_q)[:, None]
+        #         & (offs_kb < head_dim)[None, :],
+        #     )
+        # else:
+        #     tl.store(out_ptrs, o_tile, boundary_check=(0, 1))
         return
 
     # Create query pointers
@@ -454,82 +463,151 @@ def _fwd_base_kernel(
     tl.store(out_ptrs, acc_o.to(q_tile.dtype), boundary_check=(0, 1))
 
 
-def _flash_attn_forward(
+def _flash_attn_base_forward(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     softmax_scale: float,
     is_causal: bool = False,
     window_size: Optional[Tuple[int, int]] = None,
-    cu_seqlens_q: Optional[torch.Tensor] = None,
-    cu_seqlens_k: Optional[torch.Tensor] = None,
-    max_seqlen_q: Optional[int] = None,
-    max_seqlen_k: Optional[int] = None,
     pack_gqa: bool = False,
 ):
-    is_varlen = cu_seqlens_q is not None and cu_seqlens_k is not None
-    if not is_varlen:
-        batch_size, seqlen_q, num_heads_q, head_dim = query.shape
-        _, seqlen_k, num_heads_kv, _ = key.shape
-    else:
-        total_seqlen_q, num_heads_q, head_dim = query.shape
-        _, num_heads_kv, _ = key.shape
-        batch_size = cu_seqlens_q.shape[0] - 1
-        seqlen_q = max_seqlen_q
-        seqlen_k = max_seqlen_k
+    batch_size, seqlen_q, num_heads_q, head_dim = query.shape
+    _, seqlen_k, num_heads_kv, _ = key.shape
 
     is_local = window_size[0] is not None or window_size[1] is not None
     if is_local:
         window_size_left, window_size_right = window_size
     else:
         window_size_left, window_size_right = None, None
 
-    assert query.is_cuda and key.is_cuda and value.is_cuda, (
-        "All inputs must be on CUDA device"
+    utils.assert_fwd_base_inputs(
+        query,
+        key,
+        value,
+        cu_seqlens_q=None,
+        cu_seqlens_k=None,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        head_dim=head_dim,
     )
-    assert query.dtype in [torch.float16, torch.bfloat16], (
-        "Input dtype must be float16 or bfloat16"
+
+    softmax_scale = softmax_scale or 1.0 / (head_dim**0.5)
+
+    out = torch.zeros_like(query)
+    lse = torch.empty(
+        (batch_size, num_heads_q, seqlen_q),
+        device=query.device,
+        dtype=torch.float32,
     )
-    assert query.dtype == key.dtype == value.dtype, (
-        "All inputs must have the same dtype"
+
+    TILE_K = max(triton.next_power_of_2(head_dim), 16)
+
+    grid = utils.get_fwd_base_grid(
+        batch_size=batch_size,
+        seqlen_q=seqlen_q,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        pack_gqa=pack_gqa,
     )
-    assert num_heads_q % num_heads_kv == 0, (
-        "num_heads_q must be divisible by num_heads_kv"
+
+    _fwd_base_kernel[grid](
+        query,
+        key,
+        value,
+        out,
+        lse,
+        softmax_scale,
+        query.stride(0),
+        query.stride(-2),
+        query.stride(-3),
+        key.stride(0),
+        key.stride(-2),
+        key.stride(-3),
+        value.stride(0),
+        value.stride(-2),
+        value.stride(-3),
+        out.stride(0),
+        out.stride(-2),
+        out.stride(-3),
+        lse.stride(0),
+        lse.stride(1),
+        None,
+        None,
+        None,
+        None,
+        num_heads_q // num_heads_kv,
+        seqlen_q,
+        seqlen_k,
+        head_dim,
+        QHEADS_PER_KVHEAD_PACKGQA=(num_heads_q // num_heads_kv) if pack_gqa else 1,
+        TILE_K=TILE_K,
+        IS_CAUSAL=is_causal,
+        IS_LOCAL=is_local,
+        WINDOW_SIZE_LEFT=window_size_left,
+        WINDOW_SIZE_RIGHT=window_size_right,
+        HAS_CU_SEQLENS_Q=False,
+        HAS_CU_SEQLENS_K=False,
+        HAS_SEQUSED_Q=False,
+        HAS_SEQUSED_K=False,
+        PACK_GQA=pack_gqa,
     )
-    assert head_dim % 16 == 0, (
-        "head_dim must be a multiple of 16 for efficient memory access"
+
+    return out, lse, softmax_scale
+
+
+def _flash_attn_varlen_base_forward(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    cu_seqlens_q: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    max_seqlen_q: int,
+    max_seqlen_k: int,
+    softmax_scale: float,
+    is_causal: bool = False,
+    window_size: Optional[Tuple[int, int]] = None,
+    pack_gqa: bool = False,
+):
+    total_seqlen_q, num_heads_q, head_dim = query.shape
+    _, num_heads_kv, _ = key.shape
+    batch_size = cu_seqlens_q.shape[0] - 1
+    seqlen_q = max_seqlen_q
+    seqlen_k = max_seqlen_k
+
+    is_local = window_size[0] is not None or window_size[1] is not None
+    if is_local:
+        window_size_left, window_size_right = window_size
+    else:
+        window_size_left, window_size_right = None, None
+
+    utils.assert_fwd_base_inputs(
+        query,
+        key,
+        value,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        head_dim=head_dim,
     )
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
-    if is_varlen:
-        assert (
-            cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32
-        ), "cu_seqlens_q and cu_seqlens_k must be of int32"
 
     softmax_scale = softmax_scale or 1.0 / (head_dim**0.5)
 
     out = torch.zeros_like(query)
-    if not is_varlen:
-        lse = torch.empty(
-            (batch_size, num_heads_q, seqlen_q),
-            device=query.device,
-            dtype=torch.float32,
-        )
-    else:
-        lse = torch.empty(
-            (total_seqlen_q, num_heads_q), device=query.device, dtype=torch.float32
-        )
+    lse = torch.empty(
+        (total_seqlen_q, num_heads_q), device=query.device, dtype=torch.float32
+    )
 
     TILE_K = max(triton.next_power_of_2(head_dim), 16)
 
-    def grid(META):
-        return (
-            triton.cdiv(
-                seqlen_q * (num_heads_q // num_heads_kv) if pack_gqa else seqlen_q,
-                META["TILE_M"],
-            ),
-            num_heads_kv if pack_gqa else num_heads_q,
-            batch_size,
-        )
+    grid = utils.get_fwd_base_grid(
+        batch_size=batch_size,
+        seqlen_q=seqlen_q,
+        num_heads_q=num_heads_q,
+        num_heads_kv=num_heads_kv,
+        pack_gqa=pack_gqa,
+    )
 
     _fwd_base_kernel[grid](
         query,
@@ -538,18 +616,18 @@ def grid(META):
         out,
         lse,
         softmax_scale,
-        query.stride(0) if not is_varlen else 0,
+        0,
         query.stride(-2),
-        query.stride(-3) if not is_varlen else query.stride(0),
-        key.stride(0) if not is_varlen else 0,
+        query.stride(0),
+        0,
         key.stride(-2),
-        key.stride(-3) if not is_varlen else key.stride(0),
-        value.stride(0) if not is_varlen else 0,
+        key.stride(0),
+        0,
         value.stride(-2),
-        value.stride(-3) if not is_varlen else value.stride(0),
-        out.stride(0) if not is_varlen else 0,
+        value.stride(0),
+        0,
         out.stride(-2),
-        out.stride(-3) if not is_varlen else out.stride(0),
+        out.stride(0),
         lse.stride(0),
         lse.stride(1),
         cu_seqlens_q,
@@ -566,10 +644,11 @@ def grid(META):
         IS_LOCAL=is_local,
         WINDOW_SIZE_LEFT=window_size_left,
         WINDOW_SIZE_RIGHT=window_size_right,
-        HAS_CU_SEQLENS_Q=is_varlen,
-        HAS_CU_SEQLENS_K=is_varlen,
+        HAS_CU_SEQLENS_Q=True,
+        HAS_CU_SEQLENS_K=True,
         HAS_SEQUSED_Q=False,
         HAS_SEQUSED_K=False,
         PACK_GQA=pack_gqa,
     )
+
     return out, lse, softmax_scale
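
For reference, the removed inline grid closure shows exactly the launch-grid computation this commit factors out into utils.get_fwd_base_grid. Below is a minimal sketch of what that helper presumably returns, reconstructed purely from the removed closure; the real implementation may differ:

# Hypothetical reconstruction of utils.get_fwd_base_grid, based solely on the
# inline closure removed above; the actual helper may differ.
import triton


def get_fwd_base_grid(batch_size, seqlen_q, num_heads_q, num_heads_kv, pack_gqa):
    def grid(META):
        return (
            # With PACK_GQA, all query heads sharing a KV head are flattened
            # into the M axis, so the row count grows by the GQA ratio.
            triton.cdiv(
                seqlen_q * (num_heads_q // num_heads_kv) if pack_gqa else seqlen_q,
                META["TILE_M"],
            ),
            # One program per KV head when packing, else one per query head.
            num_heads_kv if pack_gqa else num_heads_q,
            batch_size,
        )

    return grid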
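A minimal usage sketch of the two split entry points, assuming fp16 CUDA tensors in the layouts the shape unpacking implies (dense: (batch, seqlen, num_heads, head_dim); varlen: packed (total_seqlen, num_heads, head_dim) with int32 cu_seqlens). Note that window_size is indexed unconditionally, so full attention still passes (None, None); softmax_scale=None falls back to 1/sqrt(head_dim); and the varlen launch hard-codes the batch-stride arguments to 0 because sequence starts come from cu_seqlens rather than a batch dimension:

import torch

# Dense path: (batch, seqlen, num_heads, head_dim); GQA with 8 query / 2 KV heads.
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 2, 64, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
out, lse, scale = _flash_attn_base_forward(
    q, k, v, softmax_scale=None, is_causal=True, window_size=(None, None)
)

# Varlen path: two sequences of lengths 500 and 524 packed along dim 0,
# delimited by int32 cumulative offsets (shared by Q and K here).
cu = torch.tensor([0, 500, 1024], device="cuda", dtype=torch.int32)
qv = torch.randn(1024, 8, 64, device="cuda", dtype=torch.float16)
kv = torch.randn(1024, 2, 64, device="cuda", dtype=torch.float16)
vv = torch.randn_like(kv)
out, lse, scale = _flash_attn_varlen_base_forward(
    qv, kv, vv, cu, cu, max_seqlen_q=524, max_seqlen_k=524,
    softmax_scale=None, is_causal=True, window_size=(None, None),
)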