1
+ import copy
1
2
import torch
2
3
import triton
3
4
import pytest
@@ -194,6 +195,8 @@ class AttentionConfig:
194
195
num_warps : gl .constexpr
195
196
196
197
SPLIT_D_FACTOR : gl .constexpr
198
+ SPLIT_EXP_FACTOR : gl .constexpr
199
+ SPLIT_QK_LOAD_FACTOR : gl .constexpr
197
200
SPLIT_M : gl .constexpr
198
201
SPLIT_D : gl .constexpr
199
202
@@ -218,7 +221,7 @@ class AttentionConfig:
218
221
use_ffma2_scale_rowmax : gl .constexpr
219
222
220
223
def __init__ (self , qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , STAGE , dtype ,
221
- num_warps , SPLIT_D_FACTOR ):
224
+ num_warps ):
222
225
self .qk_scale = qk_scale
223
226
self .Z = Z
224
227
self .H = H
@@ -232,7 +235,9 @@ def __init__(self, qk_scale, Z, H, N_CTX, BLOCK_M, BLOCK_N, HEAD_DIM, GROUP_SIZE
232
235
self .dtype = gl .constexpr (dtype )
233
236
self .num_warps = gl .constexpr (num_warps )
234
237
235
- self .SPLIT_D_FACTOR = gl .constexpr (SPLIT_D_FACTOR )
238
+ self .SPLIT_D_FACTOR = gl .constexpr (2 )
239
+ self .SPLIT_EXP_FACTOR = 256 // HEAD_DIM
240
+ self .SPLIT_QK_LOAD_FACTOR = gl .constexpr (2 if STAGE == 1 else 1 )
236
241
self .SPLIT_M = gl .constexpr (self .BLOCK_M // 2 )
237
242
self .SPLIT_D = gl .constexpr (self .HEAD_DIM // self .SPLIT_D_FACTOR )
238
243
@@ -488,6 +493,44 @@ def _borrow_s_for_epilogue(config, s_tmem):
488
493
return m_i_tmem , l_i_tmem
489
494
490
495
496
+ @gl .constexpr_function
497
+ def _get_split_n_layout (layout , SPLIT_FACTOR : gl .constexpr = 2 ):
498
+ layout = copy .deepcopy (layout )
499
+ layout .size_per_thread [1 ] //= SPLIT_FACTOR
500
+ return layout
501
+
502
+
503
+ @gluon .jit
504
+ def _split_n (x , SPLIT_FACTOR : gl .constexpr = 2 ):
505
+ if SPLIT_FACTOR == 1 :
506
+ return (x , )
507
+ else :
508
+ layout : gl .constexpr = _get_split_n_layout (x .type .layout )
509
+ x0 , x1 = x .reshape ([x .shape [0 ], 2 , x .shape [1 ] // 2 ]).permute (0 , 2 , 1 ).split ()
510
+ x0 = gl .convert_layout (x0 , layout , assert_trivial = True )
511
+ x1 = gl .convert_layout (x1 , layout , assert_trivial = True )
512
+ return _split_n (x0 , SPLIT_FACTOR // 2 ) + _split_n (x1 , SPLIT_FACTOR // 2 )
513
+
514
+
515
+ @gl .constexpr_function
516
+ def _get_join_n_layout (layout , SPLIT_FACTOR : gl .constexpr = 2 ):
517
+ layout = copy .deepcopy (layout )
518
+ layout .size_per_thread [1 ] *= SPLIT_FACTOR
519
+ return layout
520
+
521
+
522
+ @gluon .jit
523
+ def _join_n (xs ):
524
+ if len (xs ) == 1 :
525
+ return xs [0 ]
526
+ else :
527
+ x0 = _join_n (xs [:len (xs ) // 2 ])
528
+ x1 = _join_n (xs [len (xs ) // 2 :])
529
+ layout : gl .constexpr = _get_join_n_layout (x0 .type .layout )
530
+ x = gl .join (x0 , x1 ).permute (0 , 2 , 1 ).reshape ([x0 .shape [0 ], x0 .shape [1 ] * 2 ])
531
+ return gl .convert_layout (x , layout , assert_trivial = True )
532
+
533
+
491
534
@gluon .jit
492
535
def _attn_fwd_load (config , chnls , descs , M , STAGE : gl .constexpr ):
493
536
q_chnl , kv_chnl , o_chnl , epi_chnl , s0_chnl , s1_chnl , c0_chnl , c1_chnl , exp_turnstile = chnls
@@ -609,6 +652,28 @@ def _apply_causal_mask(qk, col_limit_right):
609
652
return gl .map_elementwise (_mask_scalar , qk , col_limit_right , s , i )
610
653
611
654
655
+ @gluon .jit
656
+ def _compute_and_store_exp2 (config , qk , p_tmem ):
657
+ SIZE : gl .constexpr = p_tmem .shape [1 ] // config .SPLIT_EXP_FACTOR
658
+ qks = _split_n (qk , config .SPLIT_EXP_FACTOR )
659
+ ps = ()
660
+ for i in gl .static_range (config .SPLIT_EXP_FACTOR ):
661
+ p = gl .exp2 (qks [i ])
662
+ p_tmem .slice (i * SIZE , SIZE ).store (p .to (config .dtype ))
663
+ ps = ps + (p , )
664
+ return _join_n (ps )
665
+
666
+
667
+ @gluon .jit
668
+ def _subtiled_qk_load (config , s_tmem ):
669
+ SIZE : gl .constexpr = s_tmem .shape [1 ] // config .SPLIT_QK_LOAD_FACTOR
670
+ layout : gl .constexpr = _get_split_n_layout (config .qk_layout , config .SPLIT_QK_LOAD_FACTOR )
671
+ qks = ()
672
+ for i in gl .static_range (config .SPLIT_QK_LOAD_FACTOR ):
673
+ qks = qks + (s_tmem .slice (i * SIZE , SIZE ).load (layout ), )
674
+ return _join_n (qks )
675
+
676
+
612
677
@gluon .jit
613
678
def _softmax_inner_loop (tile_id : gl .constexpr , config , prog , #
614
679
s_consumer , corr_producer , exp_turnstile , corr_bar , #
@@ -617,7 +682,7 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
617
682
618
683
for start_n in range (lo , hi , config .BLOCK_N ):
619
684
s_tmem , s_bar , s_consumer = s_consumer .acquire ()
620
- qk = s_tmem . load (config . qk_layout )
685
+ qk = _subtiled_qk_load (config , s_tmem )
621
686
622
687
if STAGE == 2 :
623
688
col_limit_right = (offs_m - start_n + 1 )[:, None ]
@@ -635,11 +700,6 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
635
700
else :
636
701
qk = _mul_f32x2 (qk , gl .full_like (qk , config .qk_scale ))
637
702
qk = _add_f32x2 (qk , - m_ij [:, None ])
638
- qk0 , qk1 , = qk .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 2 ]).permute (0 , 2 , 1 ).split ()
639
-
640
- p_tmem = _borrow_s_as_p (config , s_tmem )
641
- BN4 : gl .constexpr = config .BLOCK_N // 4
642
- BN2 : gl .constexpr = config .BLOCK_N // 2
643
703
644
704
# Force the softmax partitions to take turns in the EX2 section. This
645
705
# prevents contention for the EX2 unit and improves utilization.
@@ -649,49 +709,27 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
649
709
# FIXME: When using FADD2 reductions, ptxas misbehaves and spills far
650
710
# below the register limit in the FADD2, FMUL2, EX2 section. Subtile by
651
711
# 4 to minimize the spilling.
652
- if config .HEAD_DIM == 64 :
653
- qk00 , qk01 = qk0 .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 4 ]).permute (0 , 2 , 1 ).split ()
654
- p00 = gl .exp2 (qk00 )
655
- p_tmem .slice (0 , BN4 ).store (p00 .to (config .dtype ))
656
- p01 = gl .exp2 (qk01 )
657
- p_tmem .slice (BN4 , BN4 ).store (p01 .to (config .dtype ))
658
- p0 = gl .join (p00 , p01 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N // 2 ])
659
- p0 = gl .convert_layout (p0 , config .qk_layout )
660
- else :
661
- p0 = gl .exp2 (qk0 )
662
- p_tmem .slice (0 , BN2 ).store (p0 .to (config .dtype ))
663
-
664
- if config .HEAD_DIM == 64 :
665
- qk10 , qk11 = qk1 .reshape ([config .SPLIT_M , 2 , config .BLOCK_N // 4 ]).permute (0 , 2 , 1 ).split ()
666
- p10 = gl .exp2 (qk10 )
667
- p_tmem .slice (2 * BN4 , BN4 ).store (p10 .to (config .dtype ))
668
- p11 = gl .exp2 (qk11 )
669
- p_tmem .slice (3 * BN4 , BN4 ).store (p11 .to (config .dtype ))
670
- p1 = gl .join (p10 , p11 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N // 2 ])
671
- p1 = gl .convert_layout (p1 , config .qk_layout )
672
- else :
673
- p1 = gl .exp2 (qk1 )
674
- p_tmem .slice (BN2 , BN2 ).store (p1 .to (config .dtype ))
712
+ p_tmem = _borrow_s_as_p (config , s_tmem )
713
+ p = _compute_and_store_exp2 (config , qk , p_tmem )
675
714
676
715
mbarrier .arrive (s_bar , count = 1 )
677
-
678
716
_ , corr_bar , corr_producer = corr_producer .acquire ()
679
717
680
- if config .HEAD_DIM == 64 :
718
+ if config .use_exp2_turnstile :
681
719
mbarrier .arrive (exp_bar , count = 1 )
682
720
683
721
if config .use_fadd2_reduce :
722
+ p0 , p1 = _split_n (p )
684
723
l_ij0 , l_ij1 = gl .reduce ((p0 , p1 ), axis = 1 , combine_fn = _reduce_fadd2 )
685
724
# This is a difference of 1 SASS instruction but it dramatically
686
725
# affects instruction scheduling.
726
+ alpha = gl .convert_layout (alpha , l_i0 .type .layout , assert_trivial = True )
687
727
if config .dtype == gl .float8e5 :
688
728
l_i0 , l_i1 = _pairwise_fma_f32x2 (l_i0 , alpha , l_ij0 , l_i1 , alpha , l_ij1 )
689
729
else :
690
730
l_i0 = l_i0 * alpha + l_ij0
691
731
l_i1 = l_i1 * alpha + l_ij1
692
732
else :
693
- p = gl .join (p0 , p1 ).permute (0 , 2 , 1 ).reshape ([config .SPLIT_M , config .BLOCK_N ])
694
- p = gl .convert_layout (p , config .qk_layout )
695
733
l_ij = gl .sum (p , axis = 1 )
696
734
l_i0 = l_i0 * alpha + l_ij
697
735
@@ -704,6 +742,7 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
704
742
def _softmax_tile (tile_id : gl .constexpr , config , M , desc_o , STAGE : gl .constexpr , #
705
743
s_chnl , corr_chnl , exp_turnstile ):
706
744
qk_slice_dim1 : gl .constexpr = gl .SliceLayout (1 , config .qk_layout )
745
+ sum_layout : gl .constexpr = _get_split_n_layout (config .qk_layout ) if config .use_fadd2_reduce else config .qk_layout
707
746
708
747
s_consumer = s_chnl .create_consumer ()
709
748
corr_producer = corr_chnl .create_producer ()
@@ -717,10 +756,10 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
717
756
offs_m += gl .arange (tile_id * config .SPLIT_M , (1 + tile_id ) * config .SPLIT_M , qk_slice_dim1 )
718
757
719
758
m_i = gl .full ([config .SPLIT_M ], - float ("inf" ), gl .float32 , qk_slice_dim1 )
720
- l_i0 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , qk_slice_dim1 )
759
+ l_i0 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , gl . SliceLayout ( 1 , sum_layout ) )
721
760
# Accumulate into 2 row-sums so the reduction can be performed with FADD2.
722
761
if config .use_fadd2_reduce :
723
- l_i1 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , qk_slice_dim1 )
762
+ l_i1 = gl .full ([config .SPLIT_M ], 0.0 , gl .float32 , gl . SliceLayout ( 1 , sum_layout ) )
724
763
else :
725
764
l_i1 = 0
726
765
@@ -900,7 +939,7 @@ def attention_kernel( #
900
939
num_warps : gl .constexpr ):
901
940
qk_scale = sm_scale * 1.44269504
902
941
config = AttentionConfig (qk_scale , Z , H , N_CTX , BLOCK_M , BLOCK_N , HEAD_DIM , GROUP_SIZE_N , NUM_SMS , STAGE , #
903
- dtype , num_warps , SPLIT_D_FACTOR = 2 )
942
+ dtype , num_warps )
904
943
905
944
q_chnl = get_desc_channel (desc_q , num_buffers = 2 )
906
945
kv_chnl = get_desc_channel (desc_k , num_buffers = config .num_kv_buffers )
0 commit comments