@@ -173,6 +173,18 @@ def issue_async_tma_load(smem, bar, desc, offset):
173
173
tma .async_copy_global_to_shared (desc , [offset , 0 ], bar , smem )
174
174
175
175
176
@gluon.jit
def _interleave_n(a, b, size: gl.constexpr, f: gl.constexpr, i: gl.constexpr = 0):
    """Apply ``f`` to matching column fragments of ``a`` and ``b``.

    ``a`` and ``b`` are halved along dimension 1 recursively until
    ``a.shape[1] == size``; each fragment pair is then passed to
    ``f(a_frag, b_frag, i)``, where ``i`` is the offset of the fragment's
    first column within the tensor passed to the outermost call (0 by
    default). The per-fragment results are stitched back together in the
    original column order and converted to the layout of ``a``.

    NOTE(review): assumes ``a.shape[1]`` is ``size`` times a power of two;
    otherwise the recursion never reaches the base case -- confirm at callers.
    """
    if a.shape[1] == size:
        return f(a, b, i)
    else:
        # View the columns as [2, N/2] and move the "which half" axis last so
        # that split() yields the low-column half and the high-column half.
        a_halves = a.reshape([a.shape[0], 2, a.shape[1] // 2]).permute(0, 2, 1)
        b_halves = b.reshape([b.shape[0], 2, b.shape[1] // 2]).permute(0, 2, 1)
        a_lo, a_hi = a_halves.split()
        b_lo, b_hi = b_halves.split()

        out_lo = _interleave_n(a_lo, b_lo, size, f, i)
        # The high half starts a.shape[1] // 2 columns after the low half.
        out_hi = _interleave_n(a_hi, b_hi, size, f, i + a.shape[1] // 2)

        # Undo the split: join on a trailing axis, move it back in front of
        # the column axis, flatten, and restore the layout of `a`.
        merged = gl.join(out_lo, out_hi).permute(0, 2, 1).reshape(a.shape)
        return gl.convert_layout(merged, a.type.layout)
186
+
187
+
176
188
# ===-----------------------------------------------------------------------===#
177
189
# Gluon Attention
178
190
# ===-----------------------------------------------------------------------===#
@@ -556,7 +568,7 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
556
568
s0_tmem , s0_bar , s0_producer = s0_producer .acquire ()
557
569
p0_tmem = _borrow_s_as_p (config , s0_tmem )
558
570
tcgen05_mma (p0_tmem , v_smem , o0_tmem , use_acc = False , mbarriers = [o0_bar ])
559
- o_init = False
571
+ o1_init = False
560
572
561
573
for _ in range (num_mmas - 1 ):
562
574
k_smem , k_bar , kv_consumer = kv_consumer .acquire ()
@@ -565,43 +577,69 @@ def _attn_fwd_mma(config, chnls, descs, M, STAGE: gl.constexpr):
565
577
o1_tmem , o1_bar , o_producer = o_producer .acquire ()
566
578
s1_tmem , s1_bar , s1_producer = s1_producer .acquire ()
567
579
p1_tmem = _borrow_s_as_p (config , s1_tmem )
568
- tcgen05_mma (p1_tmem , v_smem , o1_tmem , use_acc = o_init , mbarriers = [o1_bar , v_bar ])
569
- o_init = True
580
+ tcgen05_mma (p1_tmem , v_smem , o1_tmem , use_acc = o1_init , mbarriers = [o1_bar , v_bar ])
581
+ o1_init = True
570
582
571
583
tcgen05_mma (q1_smem , k_smem .permute ((1 , 0 )), s1_tmem , use_acc = False , mbarriers = [s1_bar , k_bar ])
572
584
573
585
v_smem , v_bar , kv_consumer = kv_consumer .acquire ()
574
586
o0_tmem , o0_bar , o_producer = o_producer .acquire ()
575
587
s0_tmem , s0_bar , s0_producer = s0_producer .acquire ()
576
588
p0_tmem = _borrow_s_as_p (config , s0_tmem )
577
- tcgen05_mma (p0_tmem , v_smem , o0_tmem , use_acc = o_init , mbarriers = [o0_bar ])
578
- o_init = True
589
+ tcgen05_mma (p0_tmem , v_smem , o0_tmem , mbarriers = [o0_bar ])
579
590
580
591
tcgen05_commit (q0_bar )
581
592
tcgen05_commit (q1_bar )
582
593
583
594
o1_tmem , o1_bar , o_producer = o_producer .acquire ()
584
595
s1_tmem , s1_bar , s1_producer = s1_producer .acquire ()
585
596
p1_tmem = _borrow_s_as_p (config , s1_tmem )
586
- tcgen05_mma (p1_tmem , v_smem , o1_tmem , use_acc = o_init , mbarriers = [o1_bar , v_bar , s0_bar , s1_bar ])
597
+ tcgen05_mma (p1_tmem , v_smem , o1_tmem , use_acc = o1_init , mbarriers = [o1_bar , v_bar , s0_bar , s1_bar ])
598
+
599
+
600
@gluon.jit
def _mask_inner(qk, mask, i: gl.constexpr):
    """Keep ``qk`` where bit ``i`` of ``mask`` is clear; otherwise yield -inf."""
    bit_is_clear = (mask & (1 << i)) == 0
    return gl.where(bit_is_clear, qk, -float("inf"))
604
+
605
+
606
@gluon.jit
def _mask_frag(qk, col_limit_right, s: gl.constexpr):
    """Mask one column fragment of ``qk`` against ``col_limit_right``.

    ``s`` is subtracted from ``col_limit_right`` and the result is clamped
    at zero. ``-1 << limit`` then has its low ``limit`` bits clear, so
    ``_mask_inner`` keeps column ``j`` of the fragment when ``j < limit``
    and replaces it with -inf otherwise.

    NOTE(review): the shift amount is not clamped from above; behaviour when
    the limit reaches the integer bit width is not visible here -- confirm.
    """
    limit_in_frag = max(col_limit_right - s, 0)
    keep_bits = -1 << limit_in_frag
    # size=1: recurse down to single columns so each gets its own bit test.
    return _interleave_n(qk, keep_bits, 1, _mask_inner)
612
+
613
+
614
@gluon.jit
def _mask_bits(qk, col_limit_right):
    """Replace entries of ``qk`` at or beyond ``col_limit_right`` with -inf.

    Uses the same bit-mask test as ``_mask_frag`` / ``_mask_inner`` but
    computes it for all columns at once instead of fragment by fragment.
    """
    # FIXME: This is a more concise implementation (which compiles faster) but
    # it results in slightly slower code due to the lack of interleaving.
    column_layout: gl.constexpr = gl.SliceLayout(0, qk.type.layout)
    offs_n = gl.arange(0, qk.shape[1], layout=column_layout)[None, :]
    block_start = offs_n & ~0xf  # column index rounded down to a multiple of 16
    bit_index = offs_n & 0xf  # position of the column inside its 16-wide block

    # Limit relative to the start of the column's block, clamped at zero.
    limit_in_block = max(col_limit_right - block_start, 0)
    # -1 << limit has its low `limit` bits clear; a clear bit means "keep".
    keep_bits = -1 << limit_in_block
    bit_is_clear = (keep_bits & (1 << bit_index)) == 0
    return gl.where(bit_is_clear, qk, -float("inf"))
587
627
588
628
589
629
@gluon .jit
590
630
def _softmax_inner_loop (tile_id : gl .constexpr , config , prog , #
591
631
s_consumer , corr_producer , exp_turnstile , corr_bar , #
592
- offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE : gl .constexpr ):
632
+ offs_m , m_i , l_i0 , l_i1 , STAGE : gl .constexpr ):
593
633
lo , hi = prog .get_loop_bounds (STAGE )
594
634
595
635
for start_n in range (lo , hi , config .BLOCK_N ):
596
636
s_tmem , s_bar , s_consumer = s_consumer .acquire ()
597
637
qk = s_tmem .load (config .qk_layout )
598
638
599
639
if STAGE == 2 :
600
- # Prevent LLVM from hoisting the partial sums, which triggers spilling.
601
- offs_n = gl .inline_asm_elementwise ("mov.b32 $0, $0;" , "=r,r" , [offs_n ], dtype = gl .int32 , is_pure = True ,
602
- pack = 1 )
603
- mask = offs_m [:, None ] < (start_n + offs_n [None , :])
604
- qk = gl .where (mask , - 1.0e8 , qk )
640
+ col_limit_right = (offs_m - start_n + 1 )[:, None ].broadcast_to (qk .shape )
641
+ qk = _interleave_n (qk , col_limit_right , 16 , _mask_frag )
642
+
605
643
m_ij = gl .maximum (m_i , gl .max (qk , 1 ) * config .qk_scale )
606
644
alpha = gl .exp2 (m_i - m_ij )
607
645
@@ -682,11 +720,8 @@ def _softmax_inner_loop(tile_id: gl.constexpr, config, prog, #
682
720
@gluon .jit
683
721
def _softmax_tile (tile_id : gl .constexpr , config , M , desc_o , STAGE : gl .constexpr , #
684
722
s_chnl , corr_chnl , exp_turnstile ):
685
- qk_slice_dim0 : gl .constexpr = gl .SliceLayout (0 , config .qk_layout )
686
723
qk_slice_dim1 : gl .constexpr = gl .SliceLayout (1 , config .qk_layout )
687
724
688
- offs_n = gl .arange (0 , config .BLOCK_N , qk_slice_dim0 )
689
-
690
725
s_consumer = s_chnl .create_consumer ()
691
726
corr_producer = corr_chnl .create_producer ()
692
727
_ , corr_bar , corr_producer = corr_producer .acquire ()
@@ -709,11 +744,11 @@ def _softmax_tile(tile_id: gl.constexpr, config, M, desc_o, STAGE: gl.constexpr,
709
744
if STAGE & 1 :
710
745
m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = _softmax_inner_loop ( #
711
746
tile_id , config , prog , s_consumer , corr_producer , exp_turnstile , corr_bar , #
712
- offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE = 4 - STAGE )
747
+ offs_m , m_i , l_i0 , l_i1 , STAGE = 4 - STAGE )
713
748
if STAGE & 2 :
714
749
m_i , l_i0 , l_i1 , corr_bar , s_consumer , corr_producer , exp_turnstile = _softmax_inner_loop ( #
715
750
tile_id , config , prog , s_consumer , corr_producer , exp_turnstile , corr_bar , #
716
- offs_m , offs_n , m_i , l_i0 , l_i1 , STAGE = 2 )
751
+ offs_m , m_i , l_i0 , l_i1 , STAGE = 2 )
717
752
718
753
if config .use_fadd2_reduce :
719
754
l_i = l_i0 + l_i1
0 commit comments