@@ -2689,21 +2689,30 @@ def _persistent_reduction_configs(
26892689 xnumel = size_hints ["x" ]
26902690 rnumel = get_total_reduction_numel (size_hints )
26912691
2692- MAX_PERSISTENT_BLOCK_NUMEL = 4096
2692+ max_autotune_enabled = not disable_pointwise_autotuning (inductor_meta ) or (
2693+ inductor_meta .get ("max_autotune" )
2694+ or inductor_meta .get ("max_autotune_pointwise" )
2695+ )
26932696
2697+ configs = [
2698+ triton_config_reduction (size_hints , xblock , rnumel , register_intensive = True )
2699+ for xblock in (1 , 8 , 32 , 128 )
2700+ if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096 ))
2701+ ]
2702+
26942703 if "y" not in size_hints :
26952704 configs = [
26962705 triton_config_reduction (size_hints , xblock , rnumel , register_intensive = True )
26972706 for xblock in (1 , 8 , 32 , 128 )
26982707 if xblock == 1
2699- or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel )
2708+ or (rnumel * xblock <= 4096 and xblock <= xnumel )
27002709 ]
27012710 else :
27022711 configs = []
27032712 assert "tiling_scores" in inductor_meta
27042713 x_y_scores = {dim : inductor_meta ["tiling_scores" ][dim ] for dim in ("x" , "y" )}
27052714 for target_block_size in (1 , 8 , 32 , 64 , 128 ):
2706- if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL :
2715+ if target_block_size * rnumel > 4096 :
27072716 continue
27082717
27092718 block_sizes = match_target_block_product (
@@ -2718,19 +2727,28 @@ def _persistent_reduction_configs(
27182727 # defer to more autotuning, initially
27192728 if "y" in size_hints :
27202729 pass
2721- # TODO(jansel): we should be able to improve these heuristics
2722- elif reduction_hint == ReductionHint .INNER and rnumel >= 256 :
2723- configs = configs [:1 ]
2724- elif reduction_hint == ReductionHint .OUTER :
2725- configs = configs [- 1 :]
2726- elif reduction_hint == ReductionHint .OUTER_TINY :
2727- configs = [
2730+
2731+ if not max_autotune_enabled : # Don't filter if tuning enabled
2732+ if reduction_hint == ReductionHint .INNER and rnumel >= 256 :
2733+ configs = configs [:1 ]
2734+ elif reduction_hint == ReductionHint .OUTER :
2735+ configs = configs [- 1 :]
2736+
2737+ if reduction_hint == ReductionHint .OUTER_TINY :
2738+ tiny_configs = [
27282739 triton_config_reduction (
27292740 size_hints ,
27302741 2 * (256 // rnumel ) if rnumel <= 256 else 1 ,
27312742 rnumel ,
27322743 )
27332744 ]
2745+ if max_autotune_enabled :
2746+ for tconfig in tiny_configs :
2747+ if tconfig not in configs :
2748+ configs .append (tconfig )
2749+ else :
2750+ configs = tiny_configs
2751+
27342752 for c in configs :
27352753 # we don't need Rn_BLOCK for persistent reduction
27362754 for prefix in size_hints :
@@ -2922,20 +2940,29 @@ def user_autotune(
29222940 )
29232941
29242942
2925- def foreach (triton_meta , num_warps , filename = None , inductor_meta = None ):
2943+ def foreach (triton_meta , filename = None , inductor_meta = None ):
29262944 """
29272945 Compile a triton foreach kernel
29282946 """
2947+ configs = []
2948+ if disable_pointwise_autotuning (inductor_meta ) and not (
2949+ inductor_meta .get ("max_autotune" ) or
2950+ inductor_meta .get ("max_autotune_pointwise" )
2951+ ):
2952+ configs .append (triton .Config ({}, num_stages = 1 , num_warps = 8 ))
2953+ else :
2954+ for warps in [1 , 2 , 4 , 8 ]:
2955+ configs .append (triton .Config ({}, num_stages = 1 , num_warps = warps ))
2956+
29292957 return cached_autotune (
29302958 None ,
2931- [ triton . Config ({}, num_stages = 1 , num_warps = num_warps )] ,
2959+ configs ,
29322960 triton_meta = triton_meta ,
29332961 inductor_meta = inductor_meta ,
29342962 heuristic_type = HeuristicType .TEMPLATE ,
29352963 filename = filename ,
29362964 )
29372965
2938-
29392966@dataclasses .dataclass
29402967class GridExpr :
29412968 """Generate code for grid size expressions in launcher"""
0 commit comments