Skip to content

Commit 48b9910

Browse files
authored
[FA] add vectorization for fmul (#524)
1 parent 7183209 commit 48b9910

File tree

1 file changed

+73
-6
lines changed

1 file changed

+73
-6
lines changed

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def _attn_fwd_subtile(
4444
dtype: tl.constexpr,
4545
STAGE: tl.constexpr,
4646
SUBTILING: tl.constexpr,
47+
VECT_MUL: tl.constexpr,
4748
):
4849
qk = tl.dot(q, k)
4950
if STAGE == 2:
@@ -53,7 +54,10 @@ def _attn_fwd_subtile(
5354
qk -= m_ij[:, None]
5455
else:
5556
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
56-
qk = qk * qk_scale - m_ij[:, None]
57+
if VECT_MUL:
58+
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
59+
else:
60+
qk = qk * qk_scale - m_ij[:, None]
5761
p = tl.math.exp2(qk)
5862
# -- compute correction factor
5963
alpha = tl.math.exp2(m_i - m_ij)
@@ -65,8 +69,12 @@ def _attn_fwd_subtile(
6569

6670
if SUBTILING:
6771
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
68-
acc0 = acc0 * alpha[:, None]
69-
acc1 = acc1 * alpha[:, None]
72+
if VECT_MUL:
73+
acc0 = _mul_f32x2(acc0, alpha[:, None])
74+
acc1 = _mul_f32x2(acc1, alpha[:, None])
75+
else:
76+
acc0 = acc0 * alpha[:, None]
77+
acc1 = acc1 * alpha[:, None]
7078
acc = tl.join(acc0, acc1).permute(0, 2, 1).reshape([BM, BN])
7179
else:
7280
acc = acc * alpha[:, None]
@@ -109,6 +117,7 @@ def _attn_fwd_inner_oss_dp(
109117
N_CTX: tl.constexpr,
110118
warp_specialize: tl.constexpr,
111119
SUBTILING: tl.constexpr,
120+
VECT_MUL: tl.constexpr,
112121
):
113122
# range of values handled by this stage
114123
if STAGE == 1:
@@ -144,6 +153,7 @@ def _attn_fwd_inner_oss_dp(
144153
dtype,
145154
STAGE,
146155
SUBTILING,
156+
VECT_MUL,
147157
)
148158
l_i1, m_i1, acc1 = _attn_fwd_subtile(
149159
q1,
@@ -159,6 +169,7 @@ def _attn_fwd_inner_oss_dp(
159169
dtype,
160170
STAGE,
161171
SUBTILING,
172+
VECT_MUL,
162173
)
163174

164175
offsetkv_y += BLOCK_N
@@ -191,18 +202,25 @@ def _host_descriptor_pre_hook(nargs):
191202
if is_tile_enabled():
192203
configs = [
193204
triton.Config(
194-
{"BLOCK_M": BM, "BLOCK_N": BN, "occupancy": occ, "SUBTILING": subtile},
205+
{
206+
"BLOCK_M": BM,
207+
"BLOCK_N": BN,
208+
"occupancy": occ,
209+
"SUBTILING": subtile,
210+
"VECT_MUL": vectmul,
211+
},
195212
pre_hook=_host_descriptor_pre_hook,
196213
)
197214
for BM in [64, 128, 256]
198215
for BN in [64, 128]
199216
for occ in [1, 2]
200217
for subtile in [True]
218+
for vectmul in [True]
201219
]
202220
else:
203221
configs = [
204222
triton.Config(
205-
{"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile},
223+
{"BLOCK_M": BM, "BLOCK_N": BN, "SUBTILING": subtile, "VECT_MUL": vectmul},
206224
num_stages=s,
207225
num_warps=w,
208226
pre_hook=_host_descriptor_pre_hook,
@@ -212,7 +230,8 @@ def _host_descriptor_pre_hook(nargs):
212230
for BN in [128]
213231
for s in NUM_STAGES_OPTIONS
214232
for w in [4]
215-
for subtile in [False] # disable subtiling for now
233+
for subtile in [True]
234+
for vectmul in [False]
216235
]
217236

218237

@@ -242,6 +261,47 @@ def _maybe_make_tensor_desc(desc_or_ptr, shape, strides, block_shape):
242261
return tl.make_tensor_descriptor(desc_or_ptr, shape, strides, block_shape)
243262

244263

264+
@triton.jit
265+
def _mul_f32x2(a, b):
266+
return tl.inline_asm_elementwise(
267+
"""
268+
{
269+
.reg .b64 ra, rb, rc;
270+
mov.b64 ra, { $2, $3 };
271+
mov.b64 rb, { $4, $5 };
272+
mul.f32x2 rc, ra, rb;
273+
mov.b64 { $0, $1 }, rc;
274+
}
275+
""",
276+
"=r,=r,r,r,r,r",
277+
[a, b],
278+
dtype=tl.float32,
279+
is_pure=True,
280+
pack=2,
281+
)
282+
283+
284+
@triton.jit
285+
def _fma_f32x2(a, b, c):
286+
return tl.inline_asm_elementwise(
287+
"""
288+
{
289+
.reg .b64 ra, rb, rc, rd;
290+
mov.b64 ra, { $2, $3 };
291+
mov.b64 rb, { $4, $5 };
292+
mov.b64 rc, { $6, $7 };
293+
fma.rn.f32x2 rd, ra, rb, rc;
294+
mov.b64 { $0, $1 }, rd;
295+
}
296+
""",
297+
"=r,=r,r,r,r,r,r,r",
298+
[a, b, c],
299+
dtype=tl.float32,
300+
is_pure=True,
301+
pack=2,
302+
)
303+
304+
245305
@triton.jit
246306
def _attn_fwd_tma_dp(
247307
sm_scale,
@@ -263,6 +323,7 @@ def _attn_fwd_tma_dp(
263323
warp_specialize: tl.constexpr, #
264324
dtype: tl.constexpr,
265325
SUBTILING: tl.constexpr,
326+
VECT_MUL: tl.constexpr,
266327
):
267328
tl.static_assert(BLOCK_N <= HEAD_DIM)
268329
start_m = pid # tl.program_id(0)
@@ -317,6 +378,7 @@ def _attn_fwd_tma_dp(
317378
N_CTX, #
318379
warp_specialize,
319380
SUBTILING,
381+
VECT_MUL,
320382
)
321383
if STAGE & 2:
322384
acc0, acc1, l_i0, l_i1, m_i0, m_i1 = _attn_fwd_inner_oss_dp(
@@ -344,6 +406,7 @@ def _attn_fwd_tma_dp(
344406
N_CTX, #
345407
warp_specialize,
346408
SUBTILING,
409+
VECT_MUL,
347410
)
348411

349412
m_i0 += tl.math.log2(l_i0)
@@ -383,6 +446,7 @@ def _attn_fwd(
383446
warp_specialize: tl.constexpr, #
384447
dtype: tl.constexpr,
385448
SUBTILING: tl.constexpr,
449+
VECT_MUL: tl.constexpr,
386450
):
387451
pid = tl.program_id(0)
388452
off_hz = tl.program_id(1)
@@ -406,6 +470,7 @@ def _attn_fwd(
406470
warp_specialize,
407471
dtype,
408472
SUBTILING,
473+
VECT_MUL,
409474
)
410475

411476

@@ -434,6 +499,7 @@ def _attn_fwd_persist(
434499
OUTER_LOOP: tl.constexpr,
435500
dtype: tl.constexpr,
436501
SUBTILING: tl.constexpr,
502+
VECT_MUL: tl.constexpr,
437503
):
438504
n_tile_num = tl.cdiv(N_CTX, BLOCK_M)
439505
prog_id = tl.program_id(0)
@@ -469,6 +535,7 @@ def _attn_fwd_persist(
469535
warp_specialize and not OUTER_LOOP,
470536
dtype,
471537
SUBTILING,
538+
VECT_MUL,
472539
)
473540
tile_idx += num_progs
474541

0 commit comments

Comments
 (0)