Skip to content

Commit ed86b25

Browse files
committed
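fix: sweep maxRegAutoWS over [152, 192] in the standard autotune configs, raise the fallback maxnreg from 80 to 128, drop the BLOCK_N <= HEAD_DIM static assert in _attn_fwd_tma_dp, and treat VECT_MUL as an integer mode (1 = vectorized acc rescale, 2 = vectorized qk FMA, 3 = both) in the dp kernel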
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 18daa0d commit ed86b25

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

tritonbench/kernels/blackwell_triton_fused_attention.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -239,7 +239,7 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
239239
]
240240
else:
241241
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
242-
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
242+
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
243243
config_kwargs = {
244244
"BLOCK_M": BM,
245245
"BLOCK_N": BN,
@@ -256,20 +256,21 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
256256
# Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5)
257257
if HAS_REG_AUTO_WS:
258258
extra_kwargs["minRegAutoWS"] = 24
259-
extra_kwargs["maxRegAutoWS"] = 152
259+
extra_kwargs["maxRegAutoWS"] = maxreg
260260
extra_kwargs["data_partition_factor"] = 2
261261

262262
return triton.Config(config_kwargs, **extra_kwargs)
263263

264264
configs = [
265-
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce)
265+
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg)
266266
for BM in [256]
267267
for BN in [64, 128]
268268
for s in NUM_STAGES_OPTIONS
269269
for w in [4]
270270
for subtile in [True]
271271
for vectmul in [1]
272272
for add2reduce in [False]
273+
for maxreg in [152, 192]
273274
]
274275

275276

@@ -384,7 +385,6 @@ def _attn_fwd_tma_dp(
384385
VECT_MUL: tl.constexpr,
385386
FADD2_REDUCE: tl.constexpr,
386387
):
387-
tl.static_assert(BLOCK_N <= HEAD_DIM)
388388
start_m = pid # tl.program_id(0)
389389
# off_hz = tl.program_id(1)
390390
off_z = off_hz // H
@@ -687,7 +687,7 @@ def grid_debug(META):
687687
):
688688
extra_kern_args["maxnreg"] = 128
689689
else:
690-
extra_kern_args["maxnreg"] = 80
690+
extra_kern_args["maxnreg"] = 128
691691
if persistent:
692692
_attn_fwd_persist[grid_persist](
693693
sm_scale,

tritonbench/kernels/blackwell_triton_fused_attention_dp.py

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -72,7 +72,7 @@ def _attn_fwd_subtile(
7272
qk -= m_ij[:, None]
7373
else:
7474
m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale)
75-
if VECT_MUL:
75+
if VECT_MUL == 2 or VECT_MUL == 3:
7676
qk = _fma_f32x2(qk, qk_scale, -m_ij[:, None])
7777
else:
7878
qk = qk * qk_scale - m_ij[:, None]
@@ -88,7 +88,7 @@ def _attn_fwd_subtile(
8888

8989
if SUBTILING:
9090
acc0, acc1 = acc.reshape([BM, 2, BN // 2]).permute(0, 2, 1).split()
91-
if VECT_MUL:
91+
if VECT_MUL == 1 or VECT_MUL == 3:
9292
acc0 = _mul_f32x2(acc0, alpha[:, None])
9393
acc1 = _mul_f32x2(acc1, alpha[:, None])
9494
else:
@@ -262,12 +262,12 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
262262
for BN in [64, 128]
263263
for occ in [1, 2]
264264
for subtile in [True]
265-
for vectmul in [False]
265+
for vectmul in [0]
266266
for add2reduce in [False]
267267
]
268268
else:
269269
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
270-
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
270+
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
271271
config_kwargs = {
272272
"BLOCK_M": BM,
273273
"BLOCK_N": BN,
@@ -284,19 +284,20 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
284284
# Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5)
285285
if HAS_REG_AUTO_WS:
286286
extra_kwargs["minRegAutoWS"] = 24
287-
extra_kwargs["maxRegAutoWS"] = 152
287+
extra_kwargs["maxRegAutoWS"] = maxreg
288288

289289
return triton.Config(config_kwargs, **extra_kwargs)
290290

291291
configs = [
292-
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce)
292+
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg)
293293
for BM in [256]
294294
for BN in [64, 128]
295295
for s in NUM_STAGES_OPTIONS
296296
for w in [4]
297297
for subtile in [True]
298-
for vectmul in [False]
298+
for vectmul in [1]
299299
for add2reduce in [False]
300+
for maxreg in [152, 192]
300301
]
301302

302303

0 commit comments

Comments
 (0)