@@ -42,6 +42,23 @@ def _host_descriptor_pre_hook(nargs):
42
42
),
43
43
]
44
44
45
# Autotune candidates for the persistent warp-specialized kernel
# (`_attn_fwd_ws_persistent`). Kept separate from `configs` because the
# persistent kernel takes an extra `NUM_MMA_SLICES` constexpr.
configs_persistent = [
    triton.Config(
        dict(
            # Tile sizes.
            BLOCK_M=256,
            BLOCK_N=128,
            # Buffer counts for the Q, K/V and QK stages.
            NUM_BUFFERS_Q=1,
            NUM_BUFFERS_KV=3,
            NUM_BUFFERS_QK=1,
            # The persistent kernel statically asserts NUM_MMA_GROUPS == 2.
            NUM_MMA_GROUPS=2,
            # Number of HEAD_DIM sub-slices each accumulator tile is
            # rescaled in.
            NUM_MMA_SLICES=2,
        ),
        num_stages=0,
        num_warps=4,
        pre_hook=_host_descriptor_pre_hook,
    ),
]
61
+
45
62
46
63
@triton .jit
47
64
def _get_bufidx_phase (accum_cnt , NUM_BUFFERS_KV ):
@@ -63,6 +80,47 @@ def _compute_offsets(H, N_CTX, BLOCK_M):
63
80
return start_m , off_hz , lo , hi , qo_offset_y , kv_offset_y
64
81
65
82
83
# Packed fused multiply-add on float32 values.
#
# Evaluates a * b + c elementwise through inline PTX, processing two
# float32 elements per asm invocation (pack=2) with a single
# `fma.rn.f32x2` instruction (`.rn` = round-to-nearest).
#
# NOTE(review): `fma.rn.f32x2` is not available on every PTX target —
# confirm the minimum architecture this file is compiled for.
@triton.jit
def _fma_f32x2(a, b, c):
    return tl.inline_asm_elementwise(
        # $0/$1 are the two outputs; $2/$3, $4/$5 and $6/$7 are the element
        # pairs taken from a, b and c respectively. Each pair of 32-bit
        # registers is packed into one 64-bit register, fed to the packed
        # FMA, and the 64-bit result is unpacked into the two outputs.
        """
        {
            .reg .b64 ra, rb, rc, rd;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mov.b64 rc, { $6, $7 };
            fma.rn.f32x2 rd, ra, rb, rc;
            mov.b64 { $0, $1 }, rd;
        }
        """,
        # Two register outputs followed by six register inputs
        # (3 operands x 2 packed elements).
        "=r,=r,r,r,r,r,r,r",
        [a, b, c],
        dtype=tl.float32,
        # The asm has no side effects.
        is_pure=True,
        # Two elements are handled per asm invocation.
        pack=2,
    )
102
+
103
+
104
# Packed multiply on float32 values.
#
# Evaluates a * b elementwise through inline PTX, processing two float32
# elements per asm invocation (pack=2) with a single `mul.f32x2`
# instruction.
#
# NOTE(review): `mul.f32x2` is not available on every PTX target —
# confirm the minimum architecture this file is compiled for.
@triton.jit
def _mul_f32x2(a, b):
    return tl.inline_asm_elementwise(
        # $0/$1 are the two outputs; $2/$3 and $4/$5 are the element pairs
        # taken from a and b respectively. Each pair of 32-bit registers is
        # packed into one 64-bit register, multiplied, and the 64-bit result
        # is unpacked into the two outputs.
        """
        {
            .reg .b64 ra, rb, rc;
            mov.b64 ra, { $2, $3 };
            mov.b64 rb, { $4, $5 };
            mul.f32x2 rc, ra, rb;
            mov.b64 { $0, $1 }, rc;
        }
        """,
        # Two register outputs followed by four register inputs
        # (2 operands x 2 packed elements).
        "=r,=r,r,r,r,r",
        [a, b],
        dtype=tl.float32,
        # The asm has no side effects.
        is_pure=True,
        # Two elements are handled per asm invocation.
        pack=2,
    )
122
+
123
+
66
124
@triton .autotune (configs = configs , key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" ])
67
125
@triton .jit
68
126
def _attn_fwd_ws (
@@ -472,7 +530,7 @@ def _compute_offsets_persistent(tile_idx, n_tile_num, H, N_CTX, BLOCK_M):
472
530
return start_m , off_hz , lo , hi , qo_offset_y , kv_offset_y
473
531
474
532
475
- @triton .autotune (configs = configs , key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" ])
533
+ @triton .autotune (configs = configs_persistent , key = ["N_CTX" , "HEAD_DIM" , "FP8_OUTPUT" ])
476
534
@triton .jit
477
535
def _attn_fwd_ws_persistent (
478
536
sm_scale ,
@@ -492,6 +550,7 @@ def _attn_fwd_ws_persistent(
492
550
NUM_BUFFERS_KV : tl .constexpr , #
493
551
NUM_BUFFERS_QK : tl .constexpr , #
494
552
NUM_MMA_GROUPS : tl .constexpr , #
553
+ NUM_MMA_SLICES : tl .constexpr , #
495
554
):
496
555
tl .static_assert (BLOCK_N <= HEAD_DIM )
497
556
tl .static_assert (NUM_MMA_GROUPS == 2 )
@@ -593,24 +652,27 @@ def _attn_fwd_ws_persistent(
593
652
)
594
653
for _ in tl .range (lo , hi , BLOCK_N ):
595
654
_ , phase = _get_bufidx_phase (accum_cnt , 1 )
596
- for cid in tl .range (
597
- 0 , NUM_MMA_GROUPS , loop_unroll_factor = NUM_MMA_GROUPS
598
- ):
655
+ for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
599
656
# -- update output accumulator --
600
657
tlx .barrier_wait (alpha_fulls [cid ], phase )
601
658
# Use alpha[0] for cid=0, and alpha[HEAD_DIM] for cid=1
602
659
alpha_1 = tlx .local_load (alpha_tiles [cid * HEAD_DIM ])
603
660
tlx .barrier_arrive (alpha_empties [cid ])
604
- acc = tlx .local_load (acc_tiles [cid ])
605
- acc = acc * alpha_1
606
- tlx .local_store (acc_tiles [cid ], acc )
661
+ for slice_id in tl .static_range (0 , NUM_MMA_SLICES ):
662
+ subslice = tlx .subslice (
663
+ acc_tiles [cid ],
664
+ HEAD_DIM * slice_id // NUM_MMA_SLICES ,
665
+ HEAD_DIM // NUM_MMA_SLICES ,
666
+ )
667
+ acc = tlx .local_load (subslice )
668
+ # acc = acc * alpha_1
669
+ acc = _mul_f32x2 (acc , alpha_1 )
670
+ tlx .local_store (subslice , acc )
607
671
tlx .barrier_arrive (acc_fulls [cid ])
608
672
accum_cnt += 1
609
673
610
674
_ , phase = _get_bufidx_phase (i , 1 )
611
- for cid in tl .range (
612
- 0 , NUM_MMA_GROUPS , loop_unroll_factor = NUM_MMA_GROUPS
613
- ):
675
+ for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
614
676
# epilogue
615
677
tlx .barrier_wait (l_fulls [cid ], phase )
616
678
# Use l[1]/l[1+HEAD_DIM] and m[2][2 + HEAD_DIM]
@@ -646,9 +708,7 @@ def _attn_fwd_ws_persistent(
646
708
_compute_offsets_persistent (tile_idx , n_tile_num , H , N_CTX , BLOCK_M )
647
709
)
648
710
_ , phase = _get_bufidx_phase (i , 1 )
649
- for cid in tl .range (
650
- 0 , NUM_MMA_GROUPS , loop_unroll_factor = NUM_MMA_GROUPS
651
- ):
711
+ for cid in tl .static_range (0 , NUM_MMA_GROUPS ):
652
712
tlx .barrier_wait (o_fulls [cid ], phase )
653
713
tlx .fence_async_shared ()
654
714
tlx .barrier_arrive (o_empties [cid ])
@@ -661,7 +721,7 @@ def _attn_fwd_ws_persistent(
661
721
tile_idx += num_progs
662
722
663
723
# softmax groups
664
- with tlx .async_task (num_warps = 4 , registers = 152 , replicate = NUM_MMA_GROUPS ):
724
+ with tlx .async_task (num_warps = 4 , registers = 168 , replicate = NUM_MMA_GROUPS ):
665
725
accum_cnt_qk = 0
666
726
for i in range (0 , tiles_per_sm ):
667
727
# initialize offsets
@@ -691,7 +751,7 @@ def _attn_fwd_ws_persistent(
691
751
tlx .local_store (alpha_tiles [cid * HEAD_DIM ], alpha [:, None ])
692
752
tlx .barrier_arrive (alpha_fulls [cid ])
693
753
694
- qk = qk * qk_scale - m_ij [:, None ]
754
+ qk = _fma_f32x2 ( qk , qk_scale , - m_ij [:, None ])
695
755
p = tl .math .exp2 (qk )
696
756
l_ij = tl .sum (p , 1 )
697
757
p = p .to (tlx .dtype_of (desc_v ))
0 commit comments