@@ -173,6 +173,18 @@ def issue_async_tma_load(smem, bar, desc, offset):
173
173
tma .async_copy_global_to_shared (desc , [offset , 0 ], bar , smem )
174
174
175
175
176
@gluon.jit
def _interleave_n(a, b, size: gl.constexpr, f: gl.constexpr, i: gl.constexpr = 0):
    """Apply `f` to matching `size`-wide column fragments of `a` and `b`.

    Both tensors are recursively deinterleaved along dim 1 into even/odd
    column halves until each fragment's width equals `size`; `f(a, b, i)` is
    then invoked on every base fragment, with `i` carrying the fragment's
    column offset within the original tensor. The processed fragments are
    re-interleaved and converted back to `a`'s original layout.
    """
    if a.shape[1] == size:
        # Base case: fragment has reached the target width.
        return f(a, b, i)
    else:
        # Deinterleave each operand into its even/odd column halves.
        a_even, a_odd = a.reshape([a.shape[0], 2, a.shape[1] // 2]).permute(0, 2, 1).split()
        b_even, b_odd = b.reshape([b.shape[0], 2, b.shape[1] // 2]).permute(0, 2, 1).split()
        lo = _interleave_n(a_even, b_even, size, f, i)
        hi = _interleave_n(a_odd, b_odd, size, f, i + a.shape[1] // 2)
        # Re-interleave the two processed halves and restore the layout.
        merged = gl.join(lo, hi).permute(0, 2, 1).reshape(a.shape)
        return gl.convert_layout(merged, a.type.layout)
186
+
187
+
176
188
# ===-----------------------------------------------------------------------===#
177
189
# Gluon Attention
178
190
# ===-----------------------------------------------------------------------===#
@@ -586,22 +598,49 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
586
598
tcgen05_mma (p1_tmem , v_smem , o1_tmem , use_acc = o_init , mbarriers = [o1_bar , v_bar , s0_bar , s1_bar ])
587
599
588
600
601
@gluon.jit
def _mask_inner(qk, mask, i: gl.constexpr):
    """Replace qk elements with -inf wherever bit `i` of `mask` is set.

    `mask` holds one bit per column of the enclosing fragment; a set bit at
    position `i` marks that column as masked out.
    """
    keep = (mask & (1 << i)) == 0
    return gl.where(keep, qk, -float("inf"))
605
+
606
+
607
@gluon.jit
def _mask_frag(qk, col_limit_right, s: gl.constexpr):
    """Mask one qk fragment that starts at column offset `s`.

    Builds a per-row bitmask whose set bits mark the columns at or beyond
    the per-row limit `col_limit_right`, then lets `_interleave_n` walk the
    fragment down to single columns and test one bit per column via
    `_mask_inner`.
    """
    # Limit relative to this fragment's start, clamped so the shift amount
    # below is never negative.
    limit = max(col_limit_right - s, 0)
    # -1 << limit sets every bit at position >= limit (the masked columns).
    mask = -1 << limit
    return _interleave_n(qk, mask, 1, _mask_inner)
613
+
614
+
615
@gluon.jit
def _mask_bits(qk, col_limit_right):
    """Apply the column-limit mask to `qk` with per-element bit tests.

    FIXME: This is a more concise implementation (which compiles faster)
    than the `_interleave_n`/`_mask_frag` path, but it results in slightly
    slower code due to the lack of interleaving.
    """
    cols = gl.arange(0, qk.shape[1], layout=gl.SliceLayout(0, qk.type.layout))[None, :]
    frag_start = cols & ~0xf  # start column of each 16-wide fragment
    bit = cols & 0xf          # column position within its fragment

    # Same construction as _mask_frag: clamp the limit relative to the
    # fragment start, then set every mask bit at or beyond it.
    limit = max(col_limit_right - frag_start, 0)
    mask = -1 << limit
    keep = (mask & (1 << bit)) == 0
    return gl.where(keep, qk, -float("inf"))
628
+
629
+
589
630
@gluon .jit
590
631
def _softmax_inner_loop (tile_id : gl .constexpr , config , prog , #
591
632
s_consumer , corr_producer , exp_turnstile , corr_bar , #
592
- offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE : gl .constexpr ):
633
+ offs_m , m_i , l_i0 , l_i1 , STAGE : gl .constexpr ):
593
634
lo , hi = prog .get_loop_bounds (STAGE )
594
635
595
636
for start_n in range (lo , hi , config .BLOCK_N ):
596
637
s_tmem , s_bar , s_consumer = s_consumer .acquire ()
597
638
qk = s_tmem .load (config .qk_layout )
598
639
599
640
if STAGE == 2 :
600
- # Prevent LLVM from hoisting the partial sums, which triggers spilling.
601
- offs_n = gl .inline_asm_elementwise ("mov.b32 $0, $0;" , "=r,r" , [offs_n ], dtype = gl .int32 , is_pure = True ,
602
- pack = 1 )
603
- mask = offs_m [:, None ] < (start_n + offs_n [None , :])
604
- qk = gl .where (mask , - 1.0e8 , qk )
641
+ col_limit_right = (offs_m - start_n + 1 )[:, None ].broadcast_to (qk .shape )
642
+ qk = _interleave_n (qk , col_limit_right , 16 , _mask_frag )
643
+
605
644
m_ij = gl .maximum (m_i , gl .max (qk , 1 ) * config .qk_scale )
606
645
alpha = gl .exp2 (m_i - m_ij )
607
646
@@ -682,11 +721,8 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
682
721
@gluon .jit
683
722
def _softmax_tile (tile_id : gl .constexpr , config , M , desc_o , STAGE : gl .constexpr , #
684
723
s_chnl , corr_chnl , exp_turnstile ):
685
- qk_slice_dim0 : gl .constexpr = gl .SliceLayout (0 , config .qk_layout )
686
724
qk_slice_dim1 : gl .constexpr = gl .SliceLayout (1 , config .qk_layout )
687
725
688
- offs_n = gl .arange (0 , config .BLOCK_N , qk_slice_dim0 )
689
-
690
726
s_consumer = s_chnl .create_consumer ()
691
727
corr_producer = corr_chnl .create_producer ()
692
728
_ , corr_bar , corr_producer = corr_producer .acquire ()
@@ -709,11 +745,11 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
709
745
if STAGE & 1 :
710
746
m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = _softmax_inner_loop ( #
711
747
tile_id , config , prog , s_consumer , corr_producer , exp_turnstile , corr_bar , #
712
- offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE = 4 - STAGE )
748
+ offs_m , m_i , l_i0 , l_i1 , STAGE = 4 - STAGE )
713
749
if STAGE & 2 :
714
750
m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = _softmax_inner_loop ( #
715
751
tile_id , config , prog , s_consumer , corr_producer , exp_turnstile , corr_bar , #
716
- offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE = 2 )
752
+ offs_m , m_i , l_i0 , l_i1 , STAGE = 2 )
717
753
718
754
if config .use_fadd2_reduce :
719
755
l_i = l_i0 + l_i1
0 commit comments