@@ -765,7 +765,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
         # for some (complicated) custom Triton kernels, a register-spilling
         # config may yield the best latency.
         if not self.custom_kernel and launcher.n_spills > self.inductor_meta.get(
-            "spill_threshold", 16
+            "spill_threshold", 32
         ):
             log.debug(
                 "Skip config %s because of register spilling: %d",
@@ -2198,7 +2198,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
 
 
 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2235,7 +2235,13 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+
+    return config
 
 
 def pointwise(
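
`waves_per_eu` is an AMD-specific occupancy hint, so `triton_config_tiled_reduction` only attaches it on ROCm builds, by stashing it in `Config.kwargs` where Triton's backend picks it up. A self-contained sketch of the same pattern, assuming a Triton install (`make_hip_config` is an illustrative name, not part of the patch):

```python
import torch
from triton import Config

def make_hip_config(block_sizes: dict, num_warps: int, num_stages: int,
                    waves_per_eu=None) -> Config:
    config = Config(block_sizes, num_warps=num_warps, num_stages=num_stages)
    # torch.version.hip is None on CUDA builds and a version string on ROCm,
    # so the kwarg only reaches a backend that understands it.
    if torch.version.hip and waves_per_eu is not None:
        config.kwargs["waves_per_eu"] = waves_per_eu
    return config

# e.g. on ROCm, hint the scheduler toward 2 waves per execution unit:
# make_hip_config({"XBLOCK": 64, "YBLOCK": 64, "R0_BLOCK": 8},
#                 num_warps=4, num_stages=1, waves_per_eu=2)
```
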
@@ -2279,17 +2285,26 @@ def pointwise(
             triton_config_with_settings(
                 size_hints, bs // 2, num_elements_per_warp=64
             ),
-            # triton_config_with_settings(
-            #     size_hints, 8192, num_warps=8, num_stages=1, matrix_instr=0, waves_per_eu=2
-            # ),
             triton_config_with_settings(
                 size_hints, TRITON_MAX_BLOCK["X"], waves_per_eu=2
             ),
+            triton_config_with_settings(
+                size_hints, 4096  # wrt: better than the max block for some kernels
+            ),
             *hinted_configs,
         ]
+        # Additional pointwise configs appended for ROCm builds
+        if torch.version.hip:
+            configs.append(triton_config_with_settings(
+                size_hints,
+                2048,
+                num_warps=8,
+                num_stages=2,
+                waves_per_eu=1,
+            ))  # 20% improvement
     if len(size_hints) == 2:
         if (
-            disable_pointwise_autotuning(inductor_meta) or tile_hint == TileHint.SQUARE
+            disable_pointwise_autotuning(inductor_meta)  # or tile_hint == TileHint.SQUARE
         ) and not (
             inductor_meta.get("max_autotune")
             or inductor_meta.get("max_autotune_pointwise")
@@ -2298,9 +2313,13 @@ def pointwise(
         else:
             configs = [
                 triton_config_with_settings(size_hints, 32, 32),
+                triton_config_with_settings(size_hints, 64, 32),  # wrt: better for some kernels
                 triton_config_with_settings(size_hints, 64, 64),  # ~8% better for fp16
                 triton_config_with_settings(size_hints, 256, 16),
                 triton_config_with_settings(size_hints, 16, 256),
+                triton_config_with_settings(size_hints, 128, 16),  # wrt: +10% for some kernels
+                triton_config_with_settings(size_hints, 128, 32),  # wrt: a further ~10% on top
+                triton_config_with_settings(size_hints, 32, 512),  # wrt: +30% for some kernels
                 triton_config_with_settings(size_hints, bs, 1),
                 triton_config_with_settings(size_hints, 1, bs),
                 *hinted_configs,
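
Each 2D candidate fixes an `(XBLOCK, YBLOCK)` tile; the new entries mostly add aspect ratios at footprints the list already covers, plus one much larger 32x512 tile. A quick way to see the footprint each pair implies:

```python
# Tile footprints (elements per program) implied by the 2D candidates above.
candidates = [(32, 32), (64, 32), (64, 64), (256, 16), (16, 256),
              (128, 16), (128, 32), (32, 512)]
for xblock, yblock in candidates:
    print(f"XBLOCK={xblock:<4} YBLOCK={yblock:<4} tile={xblock * yblock}")
# The 32x512 newcomer is the only candidate above a 4096-element tile (16384).
```
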
@@ -2340,6 +2359,12 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Autotune unless pointwise autotuning is disabled and max-autotune is off
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
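
`disable_pointwise_autotuning` in current Inductor boils down to checking the `autotune_pointwise` flag, so the new predicate reads: autotune unless explicitly disabled, and always when a max-autotune flag is set. A hedged sketch (the flag lookup only approximates the real helper):

```python
# Approximation of the predicate above; `meta` stands in for inductor_meta.
def max_autotune_enabled(meta: dict) -> bool:
    disabled = not meta.get("autotune_pointwise", True)  # ~ disable_pointwise_autotuning
    return not disabled or bool(
        meta.get("max_autotune") or meta.get("max_autotune_pointwise")
    )

assert max_autotune_enabled({})                                 # on by default
assert not max_autotune_enabled({"autotune_pointwise": False})  # explicitly disabled
assert max_autotune_enabled({"autotune_pointwise": False, "max_autotune": True})
```
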
@@ -2362,7 +2387,7 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True
 
-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
             assert "tiling_scores" in inductor_meta
@@ -2374,6 +2399,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
         else:
             # For other cases, use the original function
@@ -2384,6 +2410,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu,
             )
 
     contiguous_config = make_config(
@@ -2397,32 +2424,39 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
-        pass  # skip all these cases
-    elif reduction_hint == ReductionHint.INNER:
-        return [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        return [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        return [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        return [make_config(32, 128)]
-    return [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+
+    result_configs = []
+
+    if not (max_autotune_enabled or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = [tiny_config]
+        else:
+            result_configs = [make_config(32, 128)]
+    else:
+        result_configs = [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
+
+    # On ROCm, also benchmark two waves_per_eu variants
+    if torch.version.hip:
+        result_configs.extend([
+            make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1),
+        ])
+
+    return result_configs
 
 
 def match_target_block_product(
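
Flattening the rewritten branch logic into a toy function makes the new shape of `_reduction_configs` easier to see; the config constructors are stubbed as strings, and the enum mirrors only the member names used above:

```python
from enum import Enum, auto

class Hint(Enum):  # stand-in for ReductionHint in this sketch
    INNER = auto()
    OUTER = auto()
    OUTER_TINY = auto()
    DEFAULT = auto()

def pick_reduction_configs(hint, autotune_on: bool, has_y: bool, is_hip: bool):
    if not (autotune_on or has_y):
        # one hint-directed config, with a 32x128 fallback
        by_hint = {
            Hint.INNER: ["contiguous"],
            Hint.OUTER: ["outer"],
            Hint.OUTER_TINY: ["tiny"],
        }
        result = by_hint.get(hint, ["32x128"])
    else:
        result = ["contiguous", "outer", "tiny", "64x64", "8x512", "64x4(w8)"]
    if is_hip:
        # the two ROCm waves_per_eu variants are appended unconditionally
        result += ["1024x8(w4,eu2)", "512x8(w4,eu1)"]
    return result

print(pick_reduction_configs(Hint.OUTER, autotune_on=False, has_y=False, is_hip=True))
# ['outer', '1024x8(w4,eu2)', '512x8(w4,eu1)']
```

Note that the ROCm extension runs even on the single-config hint paths, so hinted reductions on ROCm now benchmark three candidates instead of one.
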
@@ -2480,6 +2514,7 @@ def adapt_config_for_tiling(
     num_stages=1,
     register_intensive=False,
     persistent_reduction=False,
+    waves_per_eu=None,
 ) -> Config:
     """
     Create an adapted configuration based on tiling scores,
@@ -2498,6 +2533,7 @@ def adapt_config_for_tiling(
         block_sizes["r0_"],
         num_stages=num_stages,
         register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu,
     )
 
 
@@ -2608,15 +2644,15 @@ def _persistent_reduction_configs(
26082644 if "y" not in size_hints :
26092645 configs = [
26102646 triton_config_reduction (size_hints , xblock , rnumel , register_intensive = True )
2611- for xblock in (1 , 8 , 32 , 128 )
2647+ for xblock in (1 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 )
26122648 if xblock == 1
2613- or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel )
2649+ or (xblock <= xnumel and rnumel * xblock <= 4096 )
26142650 ]
26152651 else :
26162652 configs = []
26172653 assert "tiling_scores" in inductor_meta
26182654 x_y_scores = {dim : inductor_meta ["tiling_scores" ][dim ] for dim in ("x" , "y" )}
2619- for target_block_size in (1 , 8 , 32 , 64 , 128 ):
2655+ for target_block_size in (1 , 4 , 8 , 16 , 32 , 64 , 128 , 256 , 512 ):
26202656 if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL :
26212657 continue
26222658
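
Even with the widened XBLOCK candidate tuple, the `rnumel * xblock <= 4096` product bound keeps the 1D list short for large reductions. For example:

```python
# XBLOCK candidates surviving the 1D persistent-reduction filter above
# for a kernel with rnumel=512 and xnumel=64.
rnumel, xnumel = 512, 64
surviving = [
    xblock
    for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512)
    if xblock == 1 or (xblock <= xnumel and rnumel * xblock <= 4096)
]
print(surviving)  # [1, 4, 8]
```
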
@@ -2651,6 +2687,22 @@ def _persistent_reduction_configs(
         for conf in tiny_configs:
             if conf not in configs:
                 configs.append(conf)
+
+        # Also try each surviving config with the warp count halved and
+        # doubled; ROCm caps at 8 warps, CUDA at 16
+        expanded_configs = []
+        for conf in configs:
+            max_warps = 8 if torch.version.hip else 16
+            small_conf = copy.deepcopy(conf)
+            large_conf = copy.deepcopy(conf)
+            small_conf.num_warps = max(small_conf.num_warps // 2, 1)
+            large_conf.num_warps = min(large_conf.num_warps * 2, max_warps)
+            expanded_configs.append(conf)
+            expanded_configs.append(small_conf)
+            expanded_configs.append(large_conf)
+
+        configs = expanded_configs
+
     elif reduction_hint == ReductionHint.OUTER_TINY:
         configs = tiny_configs
 
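
The warp expansion triples the candidate list. A standalone sketch of the transform with a minimal stand-in for `triton.Config` (note that a config already at `num_warps=1` or at the cap yields duplicate neighbors, which this pass does not deduplicate):

```python
import copy
from dataclasses import dataclass

@dataclass
class FakeConfig:  # stand-in for triton.Config in this sketch
    kwargs: dict
    num_warps: int

def expand_warps(configs, is_hip: bool):
    max_warps = 8 if is_hip else 16  # the patch caps ROCm at 8 warps, CUDA at 16
    expanded = []
    for conf in configs:
        small, large = copy.deepcopy(conf), copy.deepcopy(conf)
        small.num_warps = max(conf.num_warps // 2, 1)
        large.num_warps = min(conf.num_warps * 2, max_warps)
        expanded += [conf, small, large]
    return expanded

out = expand_warps([FakeConfig({"XBLOCK": 8}, num_warps=4)], is_hip=True)
print([c.num_warps for c in out])  # [4, 2, 8]
```
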