@@ -2478,7 +2478,7 @@ def pointwise(
24782478
24792479
24802480def _reduction_configs (
2481- * , size_hints : dict [str , int ], inductor_meta : dict [str , Any ], num_dynamic = 0
2481+ * , size_hints : dict [str , int ], inductor_meta : dict [str , Any ]
24822482) -> list [Config ]:
24832483 reduction_hint = inductor_meta .get ("reduction_hint" , None )
24842484
@@ -2531,68 +2531,17 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
25312531 register_intensive = register_intensive ,
25322532 )
25332533
2534- def outer_config_opt ():
2535- # Default to 64 for vectorized loads
2536- max_x_block , x_block = 256 , 64
2537- load_factor = inductor_meta .get ("num_load" , 0 )
2538- x = size_hints ["x" ]
2539- num_warps = None
2540-
2541- # Try to use all SMs with small x
2542- if x <= 1024 :
2543- x_block = max (min (x // 128 , 8 ), 2 )
2544- outer_r_block = min (rnumel , 64 )
2545- # Lower bound x = 1024, 1024 // 16 = 128 around # of SMs
2546- elif x // 4096 <= 8 :
2547- x_block = 16
2548- outer_r_block = 512 // x_block
2549- elif num_dynamic > 1 :
2550- # Lots of compute with multiple dynamic shape per loop iteration
2551- # Larger RBLOCK minimizes loop iteration
2552- outer_r_block = max (min ((rnumel // 64 ), 64 ), 8 )
2553- elif num_dynamic == 1 :
2554- # Dynamic shapes introduce a lot register pressure for indexing
2555- outer_r_block = (
2556- 1
2557- if load_factor >= 3
2558- else min (next_power_of_2 (max (rnumel , 128 ) // 128 ), 8 )
2559- )
2560- else :
2561- x_block = max (min (max_x_block , next_power_of_2 (x // 4096 )), x_block )
2562- if load_factor < 4 or rnumel <= 128 :
2563- outer_r_block = 512 // x_block
2564- else :
2565- # Heavier reductions contain a lot more overhead per loop iteration
2566- # We minimize the overhead by enlarging r block
2567- if rnumel >= 2048 :
2568- outer_r_block = 64
2569- else :
2570- outer_r_block = 32
2571- x_block = min (x_block , 32 )
2572- num_warps = 4
2573-
2574- # Set register intensive to true by default as we try to maximize tiles with heuristic
2575- return make_config (
2576- x_block ,
2577- outer_r_block ,
2578- num_warps = num_warps ,
2579- register_intensive = register_intensive ,
2580- )
2581-
25822534 contiguous_config = make_config (
25832535 1 ,
25842536 min (rnumel , MAX_R0_BLOCK ),
25852537 register_intensive = register_intensive ,
25862538 )
2539+ outer_config = make_config (64 , 8 , register_intensive = register_intensive )
25872540 tiny_config = make_config (
25882541 2 * (256 // rnumel ) if rnumel <= 256 else 1 ,
25892542 min (rnumel , MAX_R0_BLOCK ),
25902543 register_intensive = register_intensive ,
25912544 )
2592-
2593- outer_config = make_config (64 , 8 , register_intensive = register_intensive )
2594- if not torch .version .hip :
2595- outer_config = outer_config_opt ()
25962545 # For 3d tiling, default to more autotuning initially
25972546 if "y" in size_hints :
25982547 pass
@@ -2712,15 +2661,7 @@ def reduction(
27122661
27132662 assert triton_meta is not None
27142663
2715- num_dynamic = 0
2716- for k in triton_meta ["signature" ].keys ():
2717- if "ks" in k :
2718- num_dynamic += 1
2719-
2720- configs = _reduction_configs (
2721- size_hints = size_hints , inductor_meta = inductor_meta , num_dynamic = num_dynamic
2722- )
2723-
2664+ configs = _reduction_configs (size_hints = size_hints , inductor_meta = inductor_meta )
27242665 configs = _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs )
27252666 return cached_autotune (
27262667 size_hints ,
0 commit comments