Skip to content

Commit 240f4d1

Browse files
committed
fix
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 18daa0d commit 240f4d1

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tritonbench/kernels/blackwell_triton_fused_attention.py

Lines changed: 5 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -239,7 +239,7 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
239239
]
240240
else:
241241
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
242-
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
242+
def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg):
243243
config_kwargs = {
244244
"BLOCK_M": BM,
245245
"BLOCK_N": BN,
@@ -256,20 +256,21 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
256256
# Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5)
257257
if HAS_REG_AUTO_WS:
258258
extra_kwargs["minRegAutoWS"] = 24
259-
extra_kwargs["maxRegAutoWS"] = 152
259+
extra_kwargs["maxRegAutoWS"] = maxreg
260260
extra_kwargs["data_partition_factor"] = 2
261261

262262
return triton.Config(config_kwargs, **extra_kwargs)
263263

264264
configs = [
265-
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce)
265+
make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce, maxreg)
266266
for BM in [256]
267267
for BN in [64, 128]
268268
for s in NUM_STAGES_OPTIONS
269269
for w in [4]
270270
for subtile in [True]
271271
for vectmul in [1]
272272
for add2reduce in [False]
273+
for maxreg in [152, 192]
273274
]
274275

275276

@@ -384,7 +385,6 @@ def _attn_fwd_tma_dp(
384385
VECT_MUL: tl.constexpr,
385386
FADD2_REDUCE: tl.constexpr,
386387
):
387-
tl.static_assert(BLOCK_N <= HEAD_DIM)
388388
start_m = pid # tl.program_id(0)
389389
# off_hz = tl.program_id(1)
390390
off_z = off_hz // H
@@ -687,7 +687,7 @@ def grid_debug(META):
687687
):
688688
extra_kern_args["maxnreg"] = 128
689689
else:
690-
extra_kern_args["maxnreg"] = 80
690+
extra_kern_args["maxnreg"] = 128
691691
if persistent:
692692
_attn_fwd_persist[grid_persist](
693693
sm_scale,

0 commit comments

Comments (0)