Skip to content

Commit 41b3d6f

Browse files
authored
[autoWS] add pingpong and subtiling_p tuning knobs in blackwell tutorial FA (#648)
1 parent 2d8ee48 commit 41b3d6f

File tree

1 file changed

+65
-9
lines changed

1 file changed

+65
-9
lines changed

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 65 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,22 @@ def _supports_reg_auto_ws():
4949
HAS_REG_AUTO_WS = _supports_reg_auto_ws()
5050

5151

52+
# Check if Triton version supports pingpongAutoWS
53+
# These parameters are only available in triton/tree/ws-3.5
54+
def _supports_pingpong_auto_ws():
55+
"""Check if the current Triton version supports pingpongAutoWS"""
56+
try:
57+
# Try to create a Config with minRegAutoWS to test support
58+
test_config = triton.Config({}, pingpongAutoWS=True)
59+
return True
60+
except (TypeError, AttributeError):
61+
# Parameter not supported in this Triton version
62+
return False
63+
64+
65+
HAS_PINGPONG_AUTO_WS = _supports_pingpong_auto_ws()
66+
67+
5268
@triton.jit
5369
def _attn_fwd_subtile(
5470
q,
@@ -65,6 +81,7 @@ def _attn_fwd_subtile(
6581
dtype: tl.constexpr,
6682
STAGE: tl.constexpr,
6783
SUBTILING: tl.constexpr,
84+
SUBTILING_P: tl.constexpr,
6885
VECT_MUL: tl.constexpr,
6986
FADD2_REDUCE: tl.constexpr,
7087
):
@@ -80,7 +97,23 @@ def _attn_fwd_subtile(
8097
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
8198
else:
8299
qk = qk * qk_scale - m_ij[:, None]
83-
p = tl.math.exp2(qk)
100+
101+
PM: tl.constexpr = qk.shape[0]
102+
PN: tl.constexpr = qk.shape[1]
103+
104+
if SUBTILING_P:
105+
qk0, qk1 = qk.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split()
106+
107+
p0 = tl.math.exp2(qk0)
108+
p0_bf16 = p0.to(dtype)
109+
p1 = tl.math.exp2(qk1)
110+
p1_bf16 = p1.to(dtype)
111+
112+
p = tl.join(p0, p1).permute(0, 2, 1).reshape([PM, PN])
113+
p_bf16 = tl.join(p0_bf16, p1_bf16).permute(0, 2, 1).reshape([PM, PN])
114+
else:
115+
p = tl.math.exp2(qk)
116+
84117
# -- compute correction factor
85118
alpha = tl.math.exp2(m_i - m_ij)
86119
if not FADD2_REDUCE:
@@ -104,8 +137,6 @@ def _attn_fwd_subtile(
104137

105138
# update m_i and l_i
106139
# place this at the end of the loop to reduce register pressure
107-
PM: tl.constexpr = p.shape[0]
108-
PN: tl.constexpr = p.shape[1]
109140
if FADD2_REDUCE:
110141
p0, p1 = p.reshape([PM, 2, PN // 2]).permute(0, 2, 1).split()
111142
l_ij0, l_ij1 = tl.reduce((p0, p1), axis=1, combine_fn=_reduce_fadd2)
@@ -115,9 +146,10 @@ def _attn_fwd_subtile(
115146
# We can potentially move these to be before updating l_ij, so the dot
116147
# is not blocked.
117148
# prepare p and v for the dot
118-
p = p.to(dtype)
149+
if not SUBTILING_P:
150+
p_bf16 = p.to(dtype)
119151
# note that this non transposed v for FP8 is only supported on Blackwell
120-
acc = tl.dot(p, v, acc)
152+
acc = tl.dot(p_bf16, v, acc)
121153
if not FADD2_REDUCE:
122154
l_i0 = l_i0 * alpha + l_ij
123155
m_i = m_ij
@@ -153,6 +185,7 @@ def _attn_fwd_inner_oss_dp(
153185
N_CTX: tl.constexpr,
154186
warp_specialize: tl.constexpr,
155187
SUBTILING: tl.constexpr,
188+
SUBTILING_P: tl.constexpr,
156189
VECT_MUL: tl.constexpr,
157190
FADD2_REDUCE: tl.constexpr,
158191
):
@@ -191,6 +224,7 @@ def _attn_fwd_inner_oss_dp(
191224
dtype,
192225
STAGE,
193226
SUBTILING,
227+
SUBTILING_P,
194228
VECT_MUL,
195229
FADD2_REDUCE,
196230
)
@@ -209,6 +243,7 @@ def _attn_fwd_inner_oss_dp(
209243
dtype,
210244
STAGE,
211245
SUBTILING,
246+
SUBTILING_P,
212247
VECT_MUL,
213248
FADD2_REDUCE,
214249
)
@@ -242,12 +277,13 @@ def _host_descriptor_pre_hook(nargs):
242277

243278
if is_tile_enabled():
244279
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
245-
def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
280+
def make_tile_config(BM, BN, occ, subtile, subtile_p, vectmul, add2reduce):
246281
config_kwargs = {
247282
"BLOCK_M": BM,
248283
"BLOCK_N": BN,
249284
"occupancy": occ,
250285
"SUBTILING": subtile,
286+
"SUBTILING_P": subtile_p,
251287
"VECT_MUL": vectmul,
252288
"FADD2_REDUCE": add2reduce,
253289
}
@@ -258,24 +294,31 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
258294
extra_kwargs["minRegAutoWS"] = 24
259295
extra_kwargs["maxRegAutoWS"] = 152
260296

297+
if HAS_PINGPONG_AUTO_WS:
298+
extra_kwargs["pingpongAutoWS"] = True
299+
261300
return triton.Config(config_kwargs, **extra_kwargs)
262301

263302
configs = [
264-
make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce)
303+
make_tile_config(BM, BN, occ, subtile, subtile_p, vectmul, add2reduce)
265304
for BM in [64, 128, 256]
266305
for BN in [64, 128]
267306
for occ in [1, 2]
268307
for subtile in [True]
308+
for subtile_p in [True]
269309
for vectmul in [0]
270310
for add2reduce in [False]
271311
]
272312
else:
273313
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
274-
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
314+
def make_standard_config(
315+
BM, BN, s, w, subtile, subtile_p, vectmul, add2reduce, maxreg
316+
):
275317
config_kwargs = {
276318
"BLOCK_M": BM,
277319
"BLOCK_N": BN,
278320
"SUBTILING": subtile,
321+
"SUBTILING_P": subtile_p,
279322
"VECT_MUL": vectmul,
280323
"FADD2_REDUCE": add2reduce,
281324
}
@@ -290,15 +333,21 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
290333
extra_kwargs["minRegAutoWS"] = 24
291334
extra_kwargs["maxRegAutoWS"] = maxreg
292335

336+
if HAS_PINGPONG_AUTO_WS:
337+
extra_kwargs["pingpongAutoWS"] = True
338+
293339
return triton.Config(config_kwargs, **extra_kwargs)
294340

295341
configs = [
296-
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg)
342+
make_standard_config(
343+
BM, BN, s, w, subtile, subtile_p, vectmul, add2reduce, maxreg
344+
)
297345
for BM in [256]
298346
for BN in [64, 128]
299347
for s in NUM_STAGES_OPTIONS
300348
for w in [4]
301349
for subtile in [True]
350+
for subtile_p in [True]
302351
for vectmul in [1]
303352
for add2reduce in [False]
304353
for maxreg in [152, 192]
@@ -420,6 +469,7 @@ def _attn_fwd_tma_dp(
420469
warp_specialize: tl.constexpr, #
421470
dtype: tl.constexpr,
422471
SUBTILING: tl.constexpr,
472+
SUBTILING_P: tl.constexpr,
423473
VECT_MUL: tl.constexpr,
424474
FADD2_REDUCE: tl.constexpr,
425475
):
@@ -485,6 +535,7 @@ def _attn_fwd_tma_dp(
485535
N_CTX, #
486536
warp_specialize,
487537
SUBTILING,
538+
SUBTILING_P,
488539
VECT_MUL,
489540
FADD2_REDUCE,
490541
)
@@ -516,6 +567,7 @@ def _attn_fwd_tma_dp(
516567
N_CTX, #
517568
warp_specialize,
518569
SUBTILING,
570+
SUBTILING_P,
519571
VECT_MUL,
520572
FADD2_REDUCE,
521573
)
@@ -564,6 +616,7 @@ def _attn_fwd(
564616
warp_specialize: tl.constexpr, #
565617
dtype: tl.constexpr,
566618
SUBTILING: tl.constexpr,
619+
SUBTILING_P: tl.constexpr,
567620
VECT_MUL: tl.constexpr,
568621
FADD2_REDUCE: tl.constexpr,
569622
):
@@ -589,6 +642,7 @@ def _attn_fwd(
589642
warp_specialize,
590643
dtype,
591644
SUBTILING,
645+
SUBTILING_P,
592646
VECT_MUL,
593647
FADD2_REDUCE,
594648
)
@@ -619,6 +673,7 @@ def _attn_fwd_persist(
619673
OUTER_LOOP: tl.constexpr,
620674
dtype: tl.constexpr,
621675
SUBTILING: tl.constexpr,
676+
SUBTILING_P: tl.constexpr,
622677
VECT_MUL: tl.constexpr,
623678
FADD2_REDUCE: tl.constexpr,
624679
):
@@ -682,6 +737,7 @@ def _attn_fwd_persist(
682737
warp_specialize and not OUTER_LOOP,
683738
dtype,
684739
SUBTILING,
740+
SUBTILING_P,
685741
VECT_MUL,
686742
FADD2_REDUCE,
687743
)

0 commit comments

Comments
 (0)