Skip to content

Commit 1d43f6d

Browse files
authored
[FA][autoWS] add support for parallel reduction (#534)
Summary: Test Plan:
1 parent b23d937 commit 1d43f6d

File tree

1 file changed

+96
-21
lines changed

1 file changed

+96
-21
lines changed

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,16 @@ def _attn_fwd_subtile(
3737
start_n,
3838
offs_n,
3939
qk_scale,
40-
l_i,
40+
l_i0,
41+
l_i1, # used when FADD2_REDUCE is true
4142
m_i,
4243
acc,
4344
v,
4445
dtype: tl.constexpr,
4546
STAGE: tl.constexpr,
4647
SUBTILING: tl.constexpr,
4748
VECT_MUL: tl.constexpr,
49+
FADD2_REDUCE: tl.constexpr,
4850
):
4951
qk = tl.dot(q, k)
5052
if STAGE == 2:
@@ -61,7 +63,8 @@ def _attn_fwd_subtile(
6163
p = tl.math.exp2(qk)
6264
# -- compute correction factor
6365
alpha = tl.math.exp2(m_i - m_ij)
64-
l_ij = tl.sum(p, 1)
66+
if not FADD2_REDUCE:
67+
l_ij = tl.sum(p, 1)
6568

6669
# -- update output accumulator --
6770
BM: tl.constexpr = acc.shape[0]
@@ -79,24 +82,37 @@ def _attn_fwd_subtile(
7982
else:
8083
acc = acc * alpha[:, None]
8184

85+
# update m_i and l_i
86+
# place this at the end of the loop to reduce register pressure
87+
PM: tl.constexpr = p.shape[0]
88+
PN: tl.constexpr = p.shape[1]
89+
if FADD2_REDUCE:
90+
p0, p1 = p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split()
91+
l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
92+
l_i0 = l_i0 * alpha + l_ij0
93+
l_i1 = l_i1 * alpha + l_ij1
94+
95+
# We can potentially move these to be before updating l_ij, so the dot
96+
# is not blocked.
8297
# prepare p and v for the dot
8398
p = p.to(dtype)
8499
# note that a non-transposed v for FP8 is only supported on Blackwell
85100
acc = tl.dot(p, v, acc)
86-
# update m_i and l_i
87-
# place this at the end of the loop to reduce register pressure
88-
l_i = l_i * alpha + l_ij
101+
if not FADD2_REDUCE:
102+
l_i0 = l_i0 * alpha + l_ij
89103
m_i = m_ij
90104

91-
return l_i, m_i, acc
105+
return l_i0, l_i1, m_i, acc
92106

93107

94108
@triton.jit
95109
def _attn_fwd_inner_oss_dp(
96110
acc0,
97111
acc1,
98112
l_i0,
113+
l_i0_1,
99114
l_i1,
115+
l_i1_1,
100116
m_i0,
101117
m_i1,
102118
q0,
@@ -118,6 +134,7 @@ def _attn_fwd_inner_oss_dp(
118134
warp_specialize: tl.constexpr,
119135
SUBTILING: tl.constexpr,
120136
VECT_MUL: tl.constexpr,
137+
FADD2_REDUCE: tl.constexpr,
121138
):
122139
# range of values handled by this stage
123140
if STAGE == 1:
@@ -139,42 +156,46 @@ def _attn_fwd_inner_oss_dp(
139156
k = desc_k.load([offsetkv_y, 0]).T
140157
v = desc_v.load([offsetkv_y, 0])
141158

142-
l_i0, m_i0, acc0 = _attn_fwd_subtile(
159+
l_i0, l_i0_1, m_i0, acc0 = _attn_fwd_subtile(
143160
q0,
144161
k,
145162
offs_m0,
146163
start_n,
147164
offs_n,
148165
qk_scale,
149166
l_i0,
167+
l_i0_1,
150168
m_i0,
151169
acc0,
152170
v,
153171
dtype,
154172
STAGE,
155173
SUBTILING,
156174
VECT_MUL,
175+
FADD2_REDUCE,
157176
)
158-
l_i1, m_i1, acc1 = _attn_fwd_subtile(
177+
l_i1, l_i1_1, m_i1, acc1 = _attn_fwd_subtile(
159178
q1,
160179
k,
161180
offs_m1,
162181
start_n,
163182
offs_n,
164183
qk_scale,
165184
l_i1,
185+
l_i1_1,
166186
m_i1,
167187
acc1,
168188
v,
169189
dtype,
170190
STAGE,
171191
SUBTILING,
172192
VECT_MUL,
193+
FADD2_REDUCE,
173194
)
174195

175196
offsetkv_y += BLOCK_N
176197

177-
return acc0, acc1, l_i0, l_i1, m_i0, m_i1
198+
return acc0, acc1, l_i0, l_i0_1, l_i1, l_i1_1, m_i0, m_i1
178199

179200

180201
def _host_descriptor_pre_hook(nargs):
@@ -208,6 +229,7 @@ def _host_descriptor_pre_hook(nargs):
208229
"occupancy": occ,
209230
"SUBTILING": subtile,
210231
"VECT_MUL": vectmul,
232+
"FADD2_REDUCE": add2reduce,
211233
},
212234
pre_hook=_host_descriptor_pre_hook,
213235
minRegAutoWS=24,
@@ -217,12 +239,19 @@ def _host_descriptor_pre_hook(nargs):
217239
for BN in [64, 128]
218240
for occ in [1, 2]
219241
for subtile in [True]
220-
for vectmul in [True]
242+
for vectmul in [False]
243+
for add2reduce in [False]
221244
]
222245
else:
223246
configs = [
224247
triton.Config(
225-
{"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile, "VECT_MUL": vectmul},
248+
{
249+
"BLOCK_M": BM,
250+
"BLOCK_N": BN,
251+
"SUBTILING": subtile,
252+
"VECT_MUL": vectmul,
253+
"FADD2_REDUCE": add2reduce,
254+
},
226255
num_stages=s,
227256
num_warps=w,
228257
pre_hook=_host_descriptor_pre_hook,
@@ -236,6 +265,7 @@ def _host_descriptor_pre_hook(nargs):
236265
for w in [4]
237266
for subtile in [True]
238267
for vectmul in [False]
268+
for add2reduce in [False]
239269
]
240270

241271

@@ -306,6 +336,26 @@ def _fma_f32x2(a, b, c):
306336
)
307337

308338

339+
@triton.jit
def _reduce_fadd2(p0a, p1a, p0b, p1b):
    """Combine two (p0, p1) float32 pairs with one packed ``add.f32x2``.

    Used as the ``combine_fn`` of a two-tensor ``tl.reduce``: the ``a``
    arguments are one element of each reduced tensor and the ``b`` arguments
    are the other element.

    Args:
        p0a: element of the first tensor, left operand.
        p1a: element of the second tensor, left operand.
        p0b: element of the first tensor, right operand.
        p1b: element of the second tensor, right operand.

    Returns:
        A pair of float32 values ``(p0a + p0b, p1a + p1b)``.
    """
    # Operand numbering inside the asm string: $0/$1 are the two outputs and
    # $2..$5 follow the order of `args`, i.e. $2=p0a, $3=p0b, $4=p1a, $5=p1b.
    # Hence ra packs {p0a, p1a}, rb packs {p0b, p1b}, and the single packed
    # add yields both sums, which are then unpacked into $0 and $1.
    # NOTE(review): add.f32x2 presumably needs a target that supports packed
    # f32x2 arithmetic -- confirm against the PTX ISA for the intended GPU.
    return tl.inline_asm_elementwise(
        asm="""
        {
            .reg .b64 rc, ra, rb;
            mov.b64 ra, { $2, $4 };
            mov.b64 rb, { $3, $5 };
            add.f32x2 rc, ra, rb;
            mov.b64 { $0, $1 }, rc;
        }
        """,
        constraints="=r,=r,r,r,r,r",
        args=[p0a, p0b, p1a, p1b],
        dtype=[tl.float32, tl.float32],
        is_pure=True,
        pack=1,
    )
357+
358+
309359
@triton.jit
310360
def _attn_fwd_tma_dp(
311361
sm_scale,
@@ -328,8 +378,9 @@ def _attn_fwd_tma_dp(
328378
dtype: tl.constexpr,
329379
SUBTILING: tl.constexpr,
330380
VECT_MUL: tl.constexpr,
381+
FADD2_REDUCE: tl.constexpr,
331382
):
332-
tl.static_assert(BLOCK_N <= HEAD_DIM)
383+
# tl.static_assert(BLOCK_N <= HEAD_DIM)
333384
start_m = pid # tl.program_id(0)
334385
# off_hz = tl.program_id(1)
335386
off_z = off_hz // H
@@ -343,11 +394,11 @@ def _attn_fwd_tma_dp(
343394
offs_n = tl.arange(0, BLOCK_N)
344395

345396
m_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
346-
l_i0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
397+
l_i0_0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
347398
acc0 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
348399

349400
m_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) - float("inf")
350-
l_i1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
401+
l_i1_0 = tl.zeros([BLOCK_M // 2], dtype=tl.float32) + 1.0
351402
acc1 = tl.zeros([BLOCK_M // 2, HEAD_DIM], dtype=tl.float32)
352403

353404
qk_scale = sm_scale
@@ -356,12 +407,21 @@ def _attn_fwd_tma_dp(
356407
q0 = desc_q.load([qo_offset_y, 0])
357408
q1 = desc_q.load([qo_offset_y + BLOCK_M // 2, 0])
358409

410+
if FADD2_REDUCE:
411+
l_i0_1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32)
412+
l_i1_1 = tl.zeros([BLOCK_M // 2], dtype=tl.float32)
413+
else:
414+
l_i0_1 = 0
415+
l_i1_1 = 0
416+
359417
if STAGE & 1:
360-
acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
418+
acc0, acc1, l_i0_0, l_i0_1, l_i1_0, l_i1_1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
361419
acc0,
362420
acc1,
363-
l_i0,
364-
l_i1,
421+
l_i0_0,
422+
l_i0_1,
423+
l_i1_0,
424+
l_i1_1,
365425
m_i0,
366426
m_i1,
367427
q0,
@@ -383,13 +443,16 @@ def _attn_fwd_tma_dp(
383443
warp_specialize,
384444
SUBTILING,
385445
VECT_MUL,
446+
FADD2_REDUCE,
386447
)
387448
if STAGE & 2:
388-
acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
449+
acc0, acc1, l_i0_0, l_i0_1, l_i1_0, l_i1_1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
389450
acc0,
390451
acc1,
391-
l_i0,
392-
l_i1,
452+
l_i0_0,
453+
l_i0_1,
454+
l_i1_0,
455+
l_i1_1,
393456
m_i0,
394457
m_i1,
395458
q0,
@@ -411,8 +474,16 @@ def _attn_fwd_tma_dp(
411474
warp_specialize,
412475
SUBTILING,
413476
VECT_MUL,
477+
FADD2_REDUCE,
414478
)
415479

480+
if FADD2_REDUCE:
481+
l_i0 = l_i0_0 + l_i0_1
482+
l_i1 = l_i1_0 + l_i1_1
483+
else:
484+
l_i0 = l_i0_0
485+
l_i1 = l_i1_0
486+
416487
m_i0 += tl.math.log2(l_i0)
417488
acc0 = acc0 / l_i0[:, None]
418489
m_ptrs0 = M + off_hz * N_CTX + offs_m0
@@ -451,6 +522,7 @@ def _attn_fwd(
451522
dtype: tl.constexpr,
452523
SUBTILING: tl.constexpr,
453524
VECT_MUL: tl.constexpr,
525+
FADD2_REDUCE: tl.constexpr,
454526
):
455527
pid = tl.program_id(0)
456528
off_hz = tl.program_id(1)
@@ -475,6 +547,7 @@ def _attn_fwd(
475547
dtype,
476548
SUBTILING,
477549
VECT_MUL,
550+
FADD2_REDUCE,
478551
)
479552

480553

@@ -493,7 +566,7 @@ def _attn_fwd_persist(
493566
desc_k,
494567
desc_v,
495568
desc_o,
496-
N_CTX, #
569+
N_CTX, #: tl.constexpr, #
497570
HEAD_DIM: tl.constexpr, #
498571
BLOCK_M: tl.constexpr, #
499572
BLOCK_N: tl.constexpr, #
@@ -504,6 +577,7 @@ def _attn_fwd_persist(
504577
dtype: tl.constexpr,
505578
SUBTILING: tl.constexpr,
506579
VECT_MUL: tl.constexpr,
580+
FADD2_REDUCE: tl.constexpr,
507581
):
508582
n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
509583
prog_id = tl.program_id(0)
@@ -540,6 +614,7 @@ def _attn_fwd_persist(
540614
dtype,
541615
SUBTILING,
542616
VECT_MUL,
617+
FADD2_REDUCE,
543618
)
544619
tile_idx += num_progs
545620

0 commit comments

Comments
 (0)