@@ -37,14 +37,16 @@ def _attn_fwd_subtile(
37
37
start_n ,
38
38
offs_n ,
39
39
qk_scale ,
40
- l_i ,
40
+ l_i0 ,
41
+ l_i1 , # used when FADD2_REDUCE is true
41
42
m_i ,
42
43
acc ,
43
44
v ,
44
45
dtype : tl .constexpr ,
45
46
STAGE : tl .constexpr ,
46
47
SUBTILING : tl .constexpr ,
47
48
VECT_MUL : tl .constexpr ,
49
+ FADD2_REDUCE : tl .constexpr ,
48
50
):
49
51
qk = tl .dot (q , k )
50
52
if STAGE == 2 :
@@ -61,7 +63,8 @@ def _attn_fwd_subtile(
61
63
p = tl .math .exp2 (qk )
62
64
# -- compute correction factor
63
65
alpha = tl .math .exp2 (m_i - m_ij )
64
- l_ij = tl .sum (p , 1 )
66
+ if not FADD2_REDUCE :
67
+ l_ij = tl .sum (p , 1 )
65
68
66
69
# -- update output accumulator --
67
70
BM : tl .constexpr = acc .shape [0 ]
@@ -79,24 +82,37 @@ def _attn_fwd_subtile(
79
82
else :
80
83
acc = acc * alpha [:, None ]
81
84
85
+ # update m_i and l_i
86
+ # place this at the end of the loop to reduce register pressure
87
+ PM : tl .constexpr = p .shape [0 ]
88
+ PN : tl .constexpr = p .shape [1 ]
89
+ if FADD2_REDUCE :
90
+ p0 , p1 = p .reshape ([PM , 2 , PN // 2 ]).permute (0 , 2 , 1 ).split ()
91
+ l_ij0 , l_ij1 = tl .reduce ((p0 , p1 ), axis = 1 , combine_fn = _reduce_fadd2 )
92
+ l_i0 = l_i0 * alpha + l_ij0
93
+ l_i1 = l_i1 * alpha + l_ij1
94
+
95
+ # We can potentially move these to be before updating l_ij, so the dot
96
+ # is not blocked.
82
97
# prepare p and v for the dot
83
98
p = p .to (dtype )
84
99
# note that this non transposed v for FP8 is only supported on Blackwell
85
100
acc = tl .dot (p , v , acc )
86
- # update m_i and l_i
87
- # place this at the end of the loop to reduce register pressure
88
- l_i = l_i * alpha + l_ij
101
+ if not FADD2_REDUCE :
102
+ l_i0 = l_i0 * alpha + l_ij
89
103
m_i = m_ij
90
104
91
- return l_i , m_i , acc
105
+ return l_i0 , l_i1 , m_i , acc
92
106
93
107
94
108
@triton .jit
95
109
def _attn_fwd_inner_oss_dp (
96
110
acc0 ,
97
111
acc1 ,
98
112
l_i0 ,
113
+ l_i0_1 ,
99
114
l_i1 ,
115
+ l_i1_1 ,
100
116
m_i0 ,
101
117
m_i1 ,
102
118
q0 ,
@@ -118,6 +134,7 @@ def _attn_fwd_inner_oss_dp(
118
134
warp_specialize : tl .constexpr ,
119
135
SUBTILING : tl .constexpr ,
120
136
VECT_MUL : tl .constexpr ,
137
+ FADD2_REDUCE : tl .constexpr ,
121
138
):
122
139
# range of values handled by this stage
123
140
if STAGE == 1 :
@@ -139,42 +156,46 @@ def _attn_fwd_inner_oss_dp(
139
156
k = desc_k .load ([offsetkv_y , 0 ]).T
140
157
v = desc_v .load ([offsetkv_y , 0 ])
141
158
142
- l_i0 , m_i0 , acc0 = _attn_fwd_subtile (
159
+ l_i0 , l_i0_1 , m_i0 , acc0 = _attn_fwd_subtile (
143
160
q0 ,
144
161
k ,
145
162
offs_m0 ,
146
163
start_n ,
147
164
offs_n ,
148
165
qk_scale ,
149
166
l_i0 ,
167
+ l_i0_1 ,
150
168
m_i0 ,
151
169
acc0 ,
152
170
v ,
153
171
dtype ,
154
172
STAGE ,
155
173
SUBTILING ,
156
174
VECT_MUL ,
175
+ FADD2_REDUCE ,
157
176
)
158
- l_i1 , m_i1 , acc1 = _attn_fwd_subtile (
177
+ l_i1 , l_i1_1 , m_i1 , acc1 = _attn_fwd_subtile (
159
178
q1 ,
160
179
k ,
161
180
offs_m1 ,
162
181
start_n ,
163
182
offs_n ,
164
183
qk_scale ,
165
184
l_i1 ,
185
+ l_i1_1 ,
166
186
m_i1 ,
167
187
acc1 ,
168
188
v ,
169
189
dtype ,
170
190
STAGE ,
171
191
SUBTILING ,
172
192
VECT_MUL ,
193
+ FADD2_REDUCE ,
173
194
)
174
195
175
196
offsetkv_y += BLOCK_N
176
197
177
- return acc0 , acc1 , l_i0 , l_i1 , m_i0 , m_i1
198
+ return acc0 , acc1 , l_i0 , l_i0_1 , l_i1 , l_i1_1 , m_i0 , m_i1
178
199
179
200
180
201
def _host_descriptor_pre_hook (nargs ):
@@ -208,6 +229,7 @@ def _host_descriptor_pre_hook(nargs):
208
229
"occupancy" : occ ,
209
230
"SUBTILING" : subtile ,
210
231
"VECT_MUL" : vectmul ,
232
+ "FADD2_REDUCE" : add2reduce ,
211
233
},
212
234
pre_hook = _host_descriptor_pre_hook ,
213
235
minRegAutoWS = 24 ,
@@ -217,12 +239,19 @@ def _host_descriptor_pre_hook(nargs):
217
239
for BN in [64 , 128 ]
218
240
for occ in [1 , 2 ]
219
241
for subtile in [True ]
220
- for vectmul in [True ]
242
+ for vectmul in [False ]
243
+ for add2reduce in [False ]
221
244
]
222
245
else :
223
246
configs = [
224
247
triton .Config (
225
- {"BLOCK_M" : BM , "BLOCK_N" : BN , "SUBTILING" : subtile , "VECT_MUL" : vectmul },
248
+ {
249
+ "BLOCK_M" : BM ,
250
+ "BLOCK_N" : BN ,
251
+ "SUBTILING" : subtile ,
252
+ "VECT_MUL" : vectmul ,
253
+ "FADD2_REDUCE" : add2reduce ,
254
+ },
226
255
num_stages = s ,
227
256
num_warps = w ,
228
257
pre_hook = _host_descriptor_pre_hook ,
@@ -236,6 +265,7 @@ def _host_descriptor_pre_hook(nargs):
236
265
for w in [4 ]
237
266
for subtile in [True ]
238
267
for vectmul in [False ]
268
+ for add2reduce in [False ]
239
269
]
240
270
241
271
@@ -306,6 +336,26 @@ def _fma_f32x2(a, b, c):
306
336
)
307
337
308
338
339
+ @triton .jit
340
+ def _reduce_fadd2 (p0a , p1a , p0b , p1b ):
341
+ return tl .inline_asm_elementwise (
342
+ """
343
+ {
344
+ .reg .b64 rc, ra, rb;
345
+ mov.b64 ra, { $2, $4 };
346
+ mov.b64 rb, { $3, $5 };
347
+ add.f32x2 rc, ra, rb;
348
+ mov.b64 { $0, $1 }, rc;
349
+ }
350
+ """ ,
351
+ "=r,=r,r,r,r,r" ,
352
+ [p0a , p0b , p1a , p1b ],
353
+ dtype = [tl .float32 , tl .float32 ],
354
+ is_pure = True ,
355
+ pack = 1 ,
356
+ )
357
+
358
+
309
359
@triton .jit
310
360
def _attn_fwd_tma_dp (
311
361
sm_scale ,
@@ -328,8 +378,9 @@ def _attn_fwd_tma_dp(
328
378
dtype : tl .constexpr ,
329
379
SUBTILING : tl .constexpr ,
330
380
VECT_MUL : tl .constexpr ,
381
+ FADD2_REDUCE : tl .constexpr ,
331
382
):
332
- tl .static_assert (BLOCK_N <= HEAD_DIM )
383
+ # tl.static_assert(BLOCK_N <= HEAD_DIM)
333
384
start_m = pid # tl.program_id(0)
334
385
# off_hz = tl.program_id(1)
335
386
off_z = off_hz // H
@@ -343,11 +394,11 @@ def _attn_fwd_tma_dp(
343
394
offs_n = tl .arange (0 , BLOCK_N )
344
395
345
396
m_i0 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 ) - float ("inf" )
346
- l_i0 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 ) + 1.0
397
+ l_i0_0 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 ) + 1.0
347
398
acc0 = tl .zeros ([BLOCK_M // 2 , HEAD_DIM ], dtype = tl .float32 )
348
399
349
400
m_i1 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 ) - float ("inf" )
350
- l_i1 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 ) + 1.0
401
+ l_i1_0 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 ) + 1.0
351
402
acc1 = tl .zeros ([BLOCK_M // 2 , HEAD_DIM ], dtype = tl .float32 )
352
403
353
404
qk_scale = sm_scale
@@ -356,12 +407,21 @@ def _attn_fwd_tma_dp(
356
407
q0 = desc_q .load ([qo_offset_y , 0 ])
357
408
q1 = desc_q .load ([qo_offset_y + BLOCK_M // 2 , 0 ])
358
409
410
+ if FADD2_REDUCE :
411
+ l_i0_1 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 )
412
+ l_i1_1 = tl .zeros ([BLOCK_M // 2 ], dtype = tl .float32 )
413
+ else :
414
+ l_i0_1 = 0
415
+ l_i1_1 = 0
416
+
359
417
if STAGE & 1 :
360
- acc0 , acc1 , l_i0 , l_i1 , m_i0 , m_i1 = _attn_fwd_inner_oss_dp (
418
+ acc0 , acc1 , l_i0_0 , l_i0_1 , l_i1_0 , l_i1_1 , m_i0 , m_i1 = _attn_fwd_inner_oss_dp (
361
419
acc0 ,
362
420
acc1 ,
363
- l_i0 ,
364
- l_i1 ,
421
+ l_i0_0 ,
422
+ l_i0_1 ,
423
+ l_i1_0 ,
424
+ l_i1_1 ,
365
425
m_i0 ,
366
426
m_i1 ,
367
427
q0 ,
@@ -383,13 +443,16 @@ def _attn_fwd_tma_dp(
383
443
warp_specialize ,
384
444
SUBTILING ,
385
445
VECT_MUL ,
446
+ FADD2_REDUCE ,
386
447
)
387
448
if STAGE & 2 :
388
- acc0 , acc1 , l_i0 , l_i1 , m_i0 , m_i1 = _attn_fwd_inner_oss_dp (
449
+ acc0 , acc1 , l_i0_0 , l_i0_1 , l_i1_0 , l_i1_1 , m_i0 , m_i1 = _attn_fwd_inner_oss_dp (
389
450
acc0 ,
390
451
acc1 ,
391
- l_i0 ,
392
- l_i1 ,
452
+ l_i0_0 ,
453
+ l_i0_1 ,
454
+ l_i1_0 ,
455
+ l_i1_1 ,
393
456
m_i0 ,
394
457
m_i1 ,
395
458
q0 ,
@@ -411,8 +474,16 @@ def _attn_fwd_tma_dp(
411
474
warp_specialize ,
412
475
SUBTILING ,
413
476
VECT_MUL ,
477
+ FADD2_REDUCE ,
414
478
)
415
479
480
+ if FADD2_REDUCE :
481
+ l_i0 = l_i0_0 + l_i0_1
482
+ l_i1 = l_i1_0 + l_i1_1
483
+ else :
484
+ l_i0 = l_i0_0
485
+ l_i1 = l_i1_0
486
+
416
487
m_i0 += tl .math .log2 (l_i0 )
417
488
acc0 = acc0 / l_i0 [:, None ]
418
489
m_ptrs0 = M + off_hz * N_CTX + offs_m0
@@ -451,6 +522,7 @@ def _attn_fwd(
451
522
dtype : tl .constexpr ,
452
523
SUBTILING : tl .constexpr ,
453
524
VECT_MUL : tl .constexpr ,
525
+ FADD2_REDUCE : tl .constexpr ,
454
526
):
455
527
pid = tl .program_id (0 )
456
528
off_hz = tl .program_id (1 )
@@ -475,6 +547,7 @@ def _attn_fwd(
475
547
dtype ,
476
548
SUBTILING ,
477
549
VECT_MUL ,
550
+ FADD2_REDUCE ,
478
551
)
479
552
480
553
@@ -493,7 +566,7 @@ def _attn_fwd_persist(
493
566
desc_k ,
494
567
desc_v ,
495
568
desc_o ,
496
- N_CTX , #
569
+ N_CTX , #: tl.constexpr, #
497
570
HEAD_DIM : tl .constexpr , #
498
571
BLOCK_M : tl .constexpr , #
499
572
BLOCK_N : tl .constexpr , #
@@ -504,6 +577,7 @@ def _attn_fwd_persist(
504
577
dtype : tl .constexpr ,
505
578
SUBTILING : tl .constexpr ,
506
579
VECT_MUL : tl .constexpr ,
580
+ FADD2_REDUCE : tl .constexpr ,
507
581
):
508
582
n_tile_num = tl .cdiv (N_CTX , BLOCK_M )
509
583
prog_id = tl .program_id (0 )
@@ -540,6 +614,7 @@ def _attn_fwd_persist(
540
614
dtype ,
541
615
SUBTILING ,
542
616
VECT_MUL ,
617
+ FADD2_REDUCE ,
543
618
)
544
619
tile_idx += num_progs
545
620
0 commit comments