@@ -2282,7 +2282,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
 
 
 def triton_config_tiled_reduction(
-    size_hints, x, y, r, num_stages=1, register_intensive=False
+    size_hints, x, y, r, num_stages=1, register_intensive=False, waves_per_eu=None
 ):
     """
     Construct a tile reduction triton config with some adjustment
@@ -2319,7 +2319,11 @@ def total_numel() -> int:
     )
     check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"])
     check_max_block(cfg)
-    return Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    config = Config(cfg, num_warps=num_warps, num_stages=num_stages)
+    if torch.version.hip:
+        if waves_per_eu is not None:
+            config.kwargs["waves_per_eu"] = waves_per_eu
+    return config
 
 
 def _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs: list[Config]):
@@ -2469,6 +2473,9 @@ def _reduction_configs(
     # Convert reductions to 1D, to simplify heuristics.
     rnumel = get_total_reduction_numel(size_hints)
 
+    # Is max autotune enabled
+    max_autotune = inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise")
+
     register_intensive = False
     MAX_R0_BLOCK = 2048
     if (
@@ -2491,7 +2498,7 @@ def _reduction_configs(
         MAX_R0_BLOCK = 1024
         register_intensive = True
 
-    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
+    def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False, waves_per_eu=None):
         # For 3D case with tiling scores, create an adapted version
         if "y" in size_hints:
             assert "tiling_scores" in inductor_meta
@@ -2503,6 +2510,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu
             )
         else:
             # For other cases, use the original function
@@ -2513,6 +2521,7 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
                 num_warps=num_warps,
                 num_stages=num_stages,
                 register_intensive=register_intensive,
+                waves_per_eu=waves_per_eu
             )
 
     contiguous_config = make_config(
@@ -2526,54 +2535,38 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
         min(rnumel, MAX_R0_BLOCK),
         register_intensive=register_intensive,
     )
-    # For 3d tiling, default to more autotuning initially
-    if "y" in size_hints:
-        pass
-    elif inductor_meta.get("max_autotune") or inductor_meta.get(
-        "max_autotune_pointwise"
-    ):
 
-        result_configs = []
-
-        # Extra ROCm tuning
-        if torch.version.hip:
-            result_configs.append(triton_config_reduction(
-                size_hints,
-                1024,
-                8,
-                num_warps=4,
-                num_stages=1,
-                waves_per_eu=2
-            ))
-            result_configs.append(triton_config_reduction(
-                size_hints,
-                512,
-                8,
-                num_warps=4,
-                num_stages=1,
-                waves_per_eu=1
-            ))
-
-    elif reduction_hint == ReductionHint.INNER:
-        result_configs = [contiguous_config]
-    elif reduction_hint == ReductionHint.OUTER:
-        result_configs = [outer_config]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        result_configs = [tiny_config]
-    if disable_pointwise_autotuning(inductor_meta):
-        result_configs = [make_config(32, 128)]
-    result_configs = [
-        contiguous_config,
-        outer_config,
-        tiny_config,
-        make_config(64, 64),
-        make_config(8, 512),
-        # halve the XBLOCK/Rn_BLOCK compared to outer_config
-        # TODO: this may only be beneficial when each iteration of the reduction
-        # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
-        make_config(64, 4, num_warps=8),
-    ]
+    result_configs = []
+
+    if not (max_autotune or "y" in size_hints):
+        if reduction_hint == ReductionHint.INNER:
+            result_configs = [contiguous_config]
+        elif reduction_hint == ReductionHint.OUTER:
+            result_configs = [outer_config]
+        elif reduction_hint == ReductionHint.OUTER_TINY:
+            result_configs = [tiny_config]
+        else:
+            result_configs = [make_config(32, 128)]
+    else:
+        result_configs = [
+            contiguous_config,
+            outer_config,
+            tiny_config,
+            make_config(64, 64),
+            make_config(8, 512),
+            # halve the XBLOCK/Rn_BLOCK compared to outer_config
+            # TODO: this may only be beneficial when each iteration of the reduction
+            # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
+            make_config(64, 4, num_warps=8),
+        ]
 
+    # Add ROCm-specific configs when autotuning
+    if torch.version.hip:
+        result_configs.extend([
+            make_config(1024, 8, num_warps=4, num_stages=1, waves_per_eu=2),
+            make_config(512, 8, num_warps=4, num_stages=1, waves_per_eu=1)
+        ])
+
     return result_configs
 
 
@@ -2632,6 +2625,7 @@ def adapt_config_for_tiling(
     num_stages=1,
     register_intensive=False,
     persistent_reduction=False,
+    waves_per_eu=None
 ) -> Config:
     """
     Create an adapted configuration based on tiling scores,
@@ -2650,6 +2644,7 @@ def adapt_config_for_tiling(
         block_sizes["r0_"],
         num_stages=num_stages,
         register_intensive=register_intensive,
+        waves_per_eu=waves_per_eu
     )
 
 
0 commit comments