@@ -2739,15 +2739,15 @@ def _persistent_reduction_configs(
27392739 if "y" not in size_hints :
27402740 configs = [
27412741 triton_config_reduction (size_hints , xblock , rnumel , register_intensive = True )
2742- for xblock in (1 , 8 , 32 , 128 )
2742+ for xblock in (1 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 )
27432743 if xblock == 1
27442744 or (xblock <= xnumel and rnumel * xblock <= 4096 )
27452745 ]
27462746 else :
27472747 configs = []
27482748 assert "tiling_scores" in inductor_meta
27492749 x_y_scores = {dim : inductor_meta ["tiling_scores" ][dim ] for dim in ("x" , "y" )}
2750- for target_block_size in (1 , 8 , 32 , 64 , 128 ):
2750+ for target_block_size in (1 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 ):
27512751 if target_block_size * rnumel > 4096 :
27522752 continue
27532753
@@ -2782,6 +2782,22 @@ def _persistent_reduction_configs(
27822782 for conf in tiny_configs :
27832783 if conf not in configs :
27842784 configs .append (conf )
2785+
2786+ # Expand configs to try additional warps
2787+ expanded_configs = []
2788+ for conf in configs :
2789+ num_warps = conf .num_warps
2790+ max_warps = 8 if torch .version .hip else 16
2791+ small_conf = copy .deepcopy (conf )
2792+ large_conf = copy .deepcopy (conf )
2793+ small_conf .num_warps = max (small_conf .num_warps // 2 , 1 )
2794+ large_conf .num_warps = min (large_conf .num_warps * 2 , max_warps )
2795+ expanded_configs .append (conf )
2796+ expanded_configs .append (small_conf )
2797+ expanded_configs .append (large_conf )
2798+
2799+ configs = expanded_configs
2800+
27852801 elif reduction_hint == ReductionHint .OUTER_TINY :
27862802 configs = tiny_configs
27872803
0 commit comments