@@ -530,7 +530,7 @@ def _dynamic_scale_rblock(self):
530530 # = regs_per_multiprocessor / (nreg * 32 * num_warps)
531531 # < regs_per_multiprocessor / ((regs_per_multiprocessor / max_threads_per_multi_processor) * 32 * num_warps)
532532 # = max_threads_per_multi_processor / (32 * num_warps)
533- # Using a tigher upper bound can reveal more optimization opportunities.
533+ # Using a tighter upper bound can reveal more optimization opportunities.
534534 max_blocks_per_sm = max (
535535 device_prop .regs_per_multiprocessor // nreg_per_block , 1
536536 )
@@ -3241,16 +3241,7 @@ def _persistent_reduction_configs(
32413241 "num_store" , 0
32423242 )
32433243
3244- max_autotune_enabled = not disable_pointwise_autotuning (inductor_meta ) or (
3245- inductor_meta .get ("max_autotune" )
3246- or inductor_meta .get ("max_autotune_pointwise" )
3247- )
3248-
3249- configs = [
3250- triton_config_reduction (size_hints , xblock , rnumel , register_intensive = True )
3251- for xblock in (1 , 8 , 32 , 128 )
3252- if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096 ))
3253- ]
3244+ MAX_PERSISTENT_BLOCK_NUMEL = 4096
32543245
32553246 if triton_meta .get ("native_matmul" ):
32563247 if len (size_hints ) == 3 :
@@ -3286,7 +3277,7 @@ def _persistent_reduction_configs(
32863277 )
32873278 for xblock in xblock_vals
32883279 if xblock == 1
3289- or (rnumel * xblock <= 4096 and xblock <= xnumel )
3280+ or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel )
32903281 ]
32913282 else :
32923283 configs = []
@@ -3559,29 +3550,20 @@ def user_autotune(
35593550 )
35603551
35613552
3562- def foreach (triton_meta , filename = None , inductor_meta = None ):
3553+ def foreach (triton_meta , num_warps , filename = None , inductor_meta = None ):
35633554 """
35643555 Compile a triton foreach kernel
35653556 """
3566- configs = []
3567- if disable_pointwise_autotuning (inductor_meta ) and not (
3568- inductor_meta .get ("max_autotune" ) or
3569- inductor_meta .get ("max_autotune_pointwise" )
3570- ):
3571- configs .append (triton .Config ({}, num_stages = 1 , num_warps = 8 ))
3572- else :
3573- for warps in [1 , 2 , 4 , 8 ]:
3574- configs .append (triton .Config ({}, num_stages = 1 , num_warps = warps ))
3575-
35763557 return cached_autotune (
35773558 None ,
3578- configs ,
3559+ [ triton . Config ({}, num_stages = 1 , num_warps = num_warps )] ,
35793560 triton_meta = triton_meta ,
35803561 inductor_meta = inductor_meta ,
35813562 heuristic_type = HeuristicType .TEMPLATE ,
35823563 filename = filename ,
35833564 )
35843565
3566+
35853567@dataclasses .dataclass
35863568class GridExpr :
35873569 """Generate code for grid size expressions in launcher"""
0 commit comments