@@ -239,7 +239,7 @@ def make_tile_config(BM, BN, occ, subtile, vectmul, add2reduce):
239
239
]
240
240
else :
241
241
# Helper to build config with optional minRegAutoWS/maxRegAutoWS
242
- def make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce ):
242
+ def make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce , maxreg ):
243
243
config_kwargs = {
244
244
"BLOCK_M" : BM ,
245
245
"BLOCK_N" : BN ,
@@ -256,20 +256,21 @@ def make_standard_config(BM, BN, s, w, subtile, vectmul, add2reduce):
256
256
# Only add minRegAutoWS/maxRegAutoWS if supported (triton/tree/ws-3.5)
257
257
if HAS_REG_AUTO_WS :
258
258
extra_kwargs ["minRegAutoWS" ] = 24
259
- extra_kwargs ["maxRegAutoWS" ] = 152
259
+ extra_kwargs ["maxRegAutoWS" ] = maxreg
260
260
extra_kwargs ["data_partition_factor" ] = 2
261
261
262
262
return triton .Config (config_kwargs , ** extra_kwargs )
263
263
264
264
configs = [
265
- make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce )
265
+ make_standard_config (BM , BN , s , w , subtile , vectmul , add2reduce , maxreg )
266
266
for BM in [256 ]
267
267
for BN in [64 , 128 ]
268
268
for s in NUM_STAGES_OPTIONS
269
269
for w in [4 ]
270
270
for subtile in [True ]
271
271
for vectmul in [1 ]
272
272
for add2reduce in [False ]
273
+ for maxreg in [152 , 192 ]
273
274
]
274
275
275
276
@@ -384,7 +385,6 @@ def _attn_fwd_tma_dp(
384
385
VECT_MUL : tl .constexpr ,
385
386
FADD2_REDUCE : tl .constexpr ,
386
387
):
387
- tl .static_assert (BLOCK_N <= HEAD_DIM )
388
388
start_m = pid # tl.program_id(0)
389
389
# off_hz = tl.program_id(1)
390
390
off_z = off_hz // H
@@ -687,7 +687,7 @@ def grid_debug(META):
687
687
):
688
688
extra_kern_args ["maxnreg" ] = 128
689
689
else :
690
- extra_kern_args ["maxnreg" ] = 80
690
+ extra_kern_args ["maxnreg" ] = 128
691
691
if persistent :
692
692
_attn_fwd_persist [grid_persist ](
693
693
sm_scale ,
0 commit comments