@@ -615,7 +615,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32 if torch.version.hip else 16
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -1779,6 +1779,8 @@ def triton_config_reduction(
     num_stages=1,
     num_warps=None,
     register_intensive=False,
+    waves_per_eu=None,
+    dynamic_scale_rblock=True,
 ) -> Config:
     """
     Construct a reduction triton config with some adjustment heuristics
@@ -1822,7 +1824,18 @@ def total_numel() -> int:
     cfg = _get_config({"x": x, **rnumels})
     check_max_block(cfg)
     check_config(cfg, xnumel=size_hints["x"])
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = InductorConfig(
+        cfg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+        dynamic_scale_rblock=dynamic_scale_rblock,
+    )
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def _get_config(numels: dict[str, int]) -> dict[str, int]:
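The hunk returns an `InductorConfig` instead of a bare `triton.Config` so the `dynamic_scale_rblock` flag can travel with the config into the autotuner. A minimal sketch of what such a subclass might look like, assuming it adds only this one flag (the real class in this file may carry more state):

```python
from triton import Config


class InductorConfig(Config):
    """Sketch: triton.Config plus Inductor-specific tuning metadata."""

    def __init__(self, *args, dynamic_scale_rblock: bool = True, **kwargs):
        super().__init__(*args, **kwargs)
        # When False, the autotuner is asked not to shrink R0_BLOCK for
        # this config when compensating for dynamic-shape register pressure.
        self.dynamic_scale_rblock = dynamic_scale_rblock
```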
@@ -1833,7 +1846,9 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
     return {prefix.upper() + "BLOCK": numel for prefix, numel in numels.items()}
 
 
-def triton_config_tiled_reduction(size_hints, x, y, r, num_stages=1):
+def triton_config_tiled_reduction(
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
+):
     """
     Construct a tile reduction triton config with some adjustment
     heuristics based on size_hints. Size_hints is a tuple of numels in
@@ -1866,7 +1881,11 @@ def total_numel() -> int:
     num_warps = _num_warps(total_numel() // 256, min_num_warps=1)
     check_config(cfg, xnumel=size_hints[0], ynumel=size_hints[1])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def pointwise(
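`waves_per_eu` is a ROCm-specific Triton kernel parameter hinting how many wavefronts the scheduler should keep resident per execution unit; stashing it in `config.kwargs` forwards it to the kernel like any other meta-parameter. A standalone sketch of the pattern (block names and the value 2 are illustrative, not taken from the patch):

```python
import torch
from triton import Config

cfg = Config({"XBLOCK": 512, "R0_BLOCK": 8}, num_warps=4, num_stages=1)
if torch.version.hip:  # the kwarg is only meaningful on ROCm backends
    # Larger values favor occupancy over registers per wavefront;
    # small values (1-2) are a common starting point for reductions.
    cfg.kwargs["waves_per_eu"] = 2
```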
@@ -1992,6 +2011,11 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Is max autotune enabled
+    max_autotune_enabled = inductor_meta.get("max_autotune") or inductor_meta.get(
+        "max_autotune_pointwise"
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
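Both `inductor_meta` entries mirror user-facing Inductor config flags, so the wider search space below can be exercised from user code. For example (the lambda is just a toy workload):

```python
import torch

# Equivalent to setting TORCHINDUCTOR_MAX_AUTOTUNE=1 in the environment;
# max_autotune_pointwise is the narrower sibling flag.
torch._inductor.config.max_autotune = True

fn = torch.compile(lambda x: x.sum(dim=-1))
out = fn(torch.randn(4096, 4096, device="cuda"))
```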
@@ -2014,8 +2038,90 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True
 
-    contiguous_config = triton_config_reduction(
-        size_hints,
+    def make_config(
+        x,
+        r,
+        num_warps=None,
+        num_stages=1,
+        register_intensive=False,
+        dynamic_scale_rblock=True,
+        waves_per_eu=None,
+    ):
+        # For the 3D case with tiling scores, create an adapted version
+        if "y" in size_hints:
+            assert "tiling_scores" in inductor_meta
+            return adapt_config_for_tiling(
+                size_hints,
+                inductor_meta["tiling_scores"],
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
+            )
+        else:
+            # For other cases, use the original function
+            return triton_config_reduction(
+                size_hints,
+                x,
+                r,
+                num_warps=num_warps,
+                num_stages=num_stages,
+                register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
+                dynamic_scale_rblock=dynamic_scale_rblock,
+            )
+
+    def outer_config_opt():
+        # Default to 64 for vectorized loads
+        max_x_block, x_block = 256, 64
+        load_factor = inductor_meta.get("num_load", 0)
+        x = size_hints["x"]
+        num_warps = None
+
+        # Try to use all SMs with a small x
+        if x <= 1024:
+            x_block = max(min(x // 128, 8), 2)
+            outer_r_block = min(rnumel, 64)
+        # Lower bound x = 1024, 1024 // 16 = 128, around the # of SMs
+        elif x // 4096 <= 8:
+            x_block = 16
+            outer_r_block = 512 // x_block
+        elif num_dynamic > 1:
+            # Lots of compute with multiple dynamic shapes per loop iteration;
+            # a larger RBLOCK minimizes the number of loop iterations
+            outer_r_block = max(min((rnumel // 64), 64), 8)
+        elif num_dynamic == 1:
+            # Dynamic shapes introduce a lot of register pressure for indexing
+            outer_r_block = (
+                1
+                if load_factor >= 3
+                else min(next_power_of_2(max(rnumel, 128) // 128), 8)
+            )
+        else:
+            x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block)
+            if load_factor < 4 or rnumel <= 128:
+                outer_r_block = 512 // x_block
+            else:
+                # Heavier reductions incur a lot more overhead per loop
+                # iteration; minimize it by enlarging the r block
+                if rnumel >= 2048:
+                    outer_r_block = 64
+                else:
+                    outer_r_block = 32
+                    x_block = min(x_block, 32)
+                num_warps = 4
+
+        # Set register intensive to true by default, as we try to maximize
+        # tiles with the heuristic
+        return make_config(
+            x_block,
+            outer_r_block,
+            num_warps=num_warps,
+            register_intensive=register_intensive,
+        )
+
+    contiguous_config = make_config(
         1,
         rnumel if 256 <= rnumel < MAX_R0_BLOCK else MAX_R0_BLOCK,
         register_intensive=register_intensive,
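To make the branch structure of `outer_config_opt` concrete, here is a worked trace under assumed inputs (all values hypothetical):

```python
# Assume x = 32768, rnumel = 512, num_dynamic = 0, num_load = 2.
x, rnumel, load_factor = 32768, 512, 2
assert x > 1024            # the "small x" branch does not fire
assert x // 4096 <= 8      # 32768 // 4096 == 8 -> "use all SMs" branch
x_block = 16
outer_r_block = 512 // x_block  # == 32
# Resulting config: XBLOCK=16, R0_BLOCK=32, num_warps left to the default.
```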
@@ -2029,27 +2135,141 @@ def _reduction_configs(
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    if inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise"):
-        pass  # skip all these cases
+
+    outer_config = make_config(64, 8, register_intensive=register_intensive)
+    # TODO (paulzhan): Test heuristic on AMD and internal testing
+    # for correctness
+    if not torch.version.hip and not is_fbcode():
+        outer_config = outer_config_opt()
+
+    configs = []
+
+    if inductor_meta.get("add_persistent_rblock") and loads_and_red <= 8:
+        xnumel = max(4096 // rnumel, 1)
+        c = make_config(
+            xnumel,
+            rnumel,
+            register_intensive=register_intensive,
+            dynamic_scale_rblock=False,
+        )
+        configs.append(c)
+
+    result_configs = []
+
+    # For 3d tiling, default to more autotuning initially
+    if "y" in size_hints:
+        pass
+    elif max_autotune_enabled:
+        pass  # skip all these cases
     elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
+        return configs + [contiguous_config]
     elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
+        return configs + [outer_config]
     elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [triton_config_reduction(size_hints, 32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        triton_config_reduction(size_hints, 64, 64),
-        triton_config_reduction(size_hints, 8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        triton_config_reduction(size_hints, 64, 4, num_warps=8),
-    ]
+        return configs + [tiny_config]
+
+    # We continue here under the following conditions:
+    # - max_autotune_enabled is True
+    # - max_autotune_enabled is False and reduction_hint is NOT one of the above cases
+    result_configs = configs + [
+        contiguous_config,
+        outer_config,
+        tiny_config,
+        make_config(64, 64),
+        make_config(8, 512),
+        # halve the XBLOCK/Rn_BLOCK compared to outer_config
+        # TODO: this may only be beneficial when each iteration of the reduction
+        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+        make_config(64, 4, num_warps=8),
+    ]
+
+    if torch.version.hip:
+        result_configs.extend(
+            [
+                make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+                make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+            ]
+        )
+
+    return result_configs
+
+
+def match_target_block_product(
+    size_hints, tiling_scores, target_block_product, min_block_size=1
+):
+    """
+    Distribute block sizes across dimensions according to tiling scores,
+    aiming to match a target product of block sizes.
+    """
+    total_score = sum(tiling_scores.values())
+    if total_score == 0:
+        # just assume even score with no minimum block size
+        min_block_size = 1
+        tiling_scores = dict.fromkeys(tiling_scores.keys(), target_block_product)
+
+    # First, give each coalescing dimension at least min_block_size
+    block_sizes = {}
+    relative_scores = {}
+    curr_block_product = 1
+
+    for dim, score in tiling_scores.items():
+        if score == 0:
+            block_sizes[dim] = 1
+            continue
+
+        block_sizes[dim] = min_block_size
+        curr_block_product *= min_block_size
+        relative_scores[dim] = score / total_score
+
+    # Scale up dimensions by their relative scores until we reach the target
+    while curr_block_product < target_block_product and len(relative_scores):
+        dim, score = max(relative_scores.items(), key=lambda item: item[1])
+
+        # Check if we've hit the max for this dimension
+        if (
+            block_sizes[dim] >= TRITON_MAX_BLOCK[dim.capitalize()]
+            or block_sizes[dim] >= size_hints[dim]
+        ):
+            del relative_scores[dim]
+            continue
+
+        block_sizes[dim] *= 2
+        relative_scores[dim] /= 2
+        curr_block_product *= 2
+
+    return block_sizes
+
+
+def adapt_config_for_tiling(
+    size_hints,
+    tiling_scores,
+    original_x,
+    original_r,
+    num_warps=None,
+    num_stages=1,
+    register_intensive=False,
+    persistent_reduction=False,
+    waves_per_eu=None,
+) -> Config:
+    """
+    Create an adapted configuration based on tiling scores,
+    redistributing the same total block size (x * r) according to tiling scores.
+    """
+    assert all(s in tiling_scores for s in size_hints)
+    target_block_product = original_x * original_r
+    block_sizes = match_target_block_product(
+        size_hints, tiling_scores, target_block_product
+    )
+
+    return triton_config_tiled_reduction(
+        size_hints,
+        block_sizes["x"],
+        block_sizes["y"],
+        block_sizes["r0_"],
+        num_stages=num_stages,
+        register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
+    )
 
 
 def reduction(
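A worked example of the doubling loop in `match_target_block_product`, with assumed size hints and scores. The dimension keys follow the file's `x`/`y`/`r0_` convention, the target of 16 stands in for `original_x * original_r`, and the import assumes the helper is reachable from this module once the patch lands:

```python
from torch._inductor.runtime.triton_heuristics import match_target_block_product

size_hints = {"x": 1024, "y": 1024, "r0_": 1024}
tiling_scores = {"x": 4, "y": 1, "r0_": 3}

# total_score = 8 -> relative scores: x = 0.5, r0_ = 0.375, y = 0.125.
# Every scored dim starts at block size 1; the loop then repeatedly doubles
# the dim with the highest relative score and halves that score:
#   x: 1->2 (0.25), r0_: 1->2 (0.1875), x: 2->4 (0.125), r0_: 2->4 (0.09375)
# stopping once the block product reaches the target of 16.
blocks = match_target_block_product(size_hints, tiling_scores, 16)
assert blocks == {"x": 4, "y": 1, "r0_": 4}
```

`adapt_config_for_tiling` then feeds the redistributed block sizes into `triton_config_tiled_reduction`, preserving the original `x * r` footprint while matching the measured coalescing pattern.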