@@ -838,7 +838,7 @@ def bench(self, launcher, *args, with_profiler=False, **kwargs):
838838 # for some (complicated) custom Triton kernels, a register-spilling
839839 # config may yield the best latency.
840840 if not self .custom_kernel and launcher .n_spills > self .inductor_meta .get (
841- "spill_threshold" , 16
841+ "spill_threshold" , 32 if torch . version . hip else 16
842842 ):
843843 log .debug (
844844 "Skip config %s because of register spilling: %d" ,
@@ -2312,6 +2312,7 @@ def triton_config_reduction(
23122312 num_stages = 1 ,
23132313 num_warps = None ,
23142314 register_intensive = False ,
2315+ waves_per_eu = None ,
23152316 dynamic_scale_rblock = True ,
23162317) -> Config :
23172318 """
@@ -2356,13 +2357,19 @@ def total_numel() -> int:
23562357 cfg = _get_config ({"x" : x , ** rnumels })
23572358 check_max_block (cfg )
23582359 check_config (cfg , xnumel = size_hints ["x" ])
2359- return InductorConfig (
2360+ config = InductorConfig (
23602361 cfg ,
23612362 num_warps = num_warps ,
23622363 num_stages = num_stages ,
23632364 dynamic_scale_rblock = dynamic_scale_rblock ,
23642365 )
23652366
2367+ if torch .version .hip :
2368+ if waves_per_eu is not None :
2369+ config .kwargs ["waves_per_eu" ] = waves_per_eu
2370+
2371+ return config
2372+
23662373
23672374def _get_config (numels : dict [str , int ]) -> dict [str , int ]:
23682375 """
@@ -2373,7 +2380,7 @@ def _get_config(numels: dict[str, int]) -> dict[str, int]:
23732380
23742381
23752382def triton_config_tiled_reduction (
2376- size_hints , x , y , r , num_stages = 1 , register_intensive = False
2383+ size_hints , x , y , r , num_stages = 1 , register_intensive = False , waves_per_eu = None
23772384):
23782385 """
23792386 Construct a tile reduction triton config with some adjustment
@@ -2410,7 +2417,11 @@ def total_numel() -> int:
24102417 )
24112418 check_config (cfg , xnumel = size_hints ["x" ], ynumel = size_hints ["y" ])
24122419 check_max_block (cfg )
2413- return Config (cfg , num_warps = num_warps , num_stages = num_stages )
2420+ config = Config (cfg , num_warps = num_warps , num_stages = num_stages )
2421+ if torch .version .hip :
2422+ if waves_per_eu is not None :
2423+ config .kwargs ["waves_per_eu" ] = waves_per_eu
2424+ return config
24142425
24152426
24162427def _maybe_filter_configs_for_tma_restrictions (inductor_meta , configs : list [Config ]):
@@ -2584,6 +2595,11 @@ def _reduction_configs(
25842595 # Convert reductions to 1D, to simplify heuristics.
25852596 rnumel = get_total_reduction_numel (size_hints )
25862597
2598+ # Whether max autotune is enabled (either globally or for pointwise ops)
2599+ max_autotune_enabled = inductor_meta .get ("max_autotune" ) or inductor_meta .get (
2600+ "max_autotune_pointwise"
2601+ )
2602+
25872603 register_intensive = False
25882604 MAX_R0_BLOCK = 2048
25892605 loads_and_red = inductor_meta .get ("num_load" , 0 ) + inductor_meta .get (
@@ -2612,6 +2628,7 @@ def make_config(
26122628 num_stages = 1 ,
26132629 register_intensive = False ,
26142630 dynamic_scale_rblock = True ,
2631+ waves_per_eu = None ,
26152632 ):
26162633 # For 3D case with tiling scores, create an adapted version
26172634 if "y" in size_hints :
@@ -2624,6 +2641,7 @@ def make_config(
26242641 num_warps = num_warps ,
26252642 num_stages = num_stages ,
26262643 register_intensive = register_intensive ,
2644+ waves_per_eu = waves_per_eu ,
26272645 )
26282646 else :
26292647 # For other cases, use the original function
@@ -2634,6 +2652,7 @@ def make_config(
26342652 num_warps = num_warps ,
26352653 num_stages = num_stages ,
26362654 register_intensive = register_intensive ,
2655+ waves_per_eu = waves_per_eu ,
26372656 dynamic_scale_rblock = dynamic_scale_rblock ,
26382657 )
26392658
@@ -2714,33 +2733,40 @@ def outer_config_opt():
27142733 )
27152734 configs .append (c )
27162735
2736+ result_configs = []
2737+
27172738 # For 3d tiling, default to more autotuning initially
2718- if "y" in size_hints :
2719- pass
2720- elif inductor_meta .get ("max_autotune" ) or inductor_meta .get (
2721- "max_autotune_pointwise"
2722- ):
2723- pass # skip all these cases
2724- elif reduction_hint == ReductionHint .INNER :
2725- return configs + [contiguous_config ]
2726- elif reduction_hint == ReductionHint .OUTER :
2727- return configs + [outer_config ]
2728- elif reduction_hint == ReductionHint .OUTER_TINY :
2729- return configs + [tiny_config ]
2730- if disable_pointwise_autotuning (inductor_meta ):
2731- return configs + [make_config (32 , 128 )]
2732-
2733- return configs + [
2734- contiguous_config ,
2735- outer_config ,
2736- tiny_config ,
2737- make_config (64 , 64 ),
2738- make_config (8 , 512 ),
2739- # halve the XBLOCK/Rn_BLOCK compared to outer_config
2740- # TODO: this may only be beneficial when each iteration of the reduction
2741- # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
2742- make_config (64 , 4 , num_warps = 8 ),
2743- ]
2739+ if not (max_autotune_enabled or "y" in size_hints ):
2740+ if reduction_hint == ReductionHint .INNER :
2741+ result_configs = configs + [contiguous_config ]
2742+ elif reduction_hint == ReductionHint .OUTER :
2743+ result_configs = configs + [outer_config ]
2744+ elif reduction_hint == ReductionHint .OUTER_TINY :
2745+ result_configs = configs + [tiny_config ]
2746+ else :
2747+ result_configs = configs + [make_config (32 , 128 )]
2748+ else :
2749+ result_configs = configs + [
2750+ contiguous_config ,
2751+ outer_config ,
2752+ tiny_config ,
2753+ make_config (64 , 64 ),
2754+ make_config (8 , 512 ),
2755+ # halve the XBLOCK/Rn_BLOCK compared to outer_config
2756+ # TODO: this may only be beneficial when each iteration of the reduction
2757+ # is quite heavy. E.g. https://gist.github.com/shunting314/189a8ef69f90db9d614a823385147a72
2758+ make_config (64 , 4 , num_warps = 8 ),
2759+ ]
2760+
2761+ if torch .version .hip :
2762+ result_configs .extend (
2763+ [
2764+ make_config (1024 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 2 ),
2765+ make_config (512 , 8 , num_warps = 4 , num_stages = 1 , waves_per_eu = 1 ),
2766+ ]
2767+ )
2768+
2769+ return result_configs
27442770
27452771
27462772def match_target_block_product (
@@ -2798,6 +2824,7 @@ def adapt_config_for_tiling(
27982824 num_stages = 1 ,
27992825 register_intensive = False ,
28002826 persistent_reduction = False ,
2827+ waves_per_eu = None ,
28012828) -> Config :
28022829 """
28032830 Create an adapted configuration based on tiling scores,
@@ -2816,6 +2843,7 @@ def adapt_config_for_tiling(
28162843 block_sizes ["r0_" ],
28172844 num_stages = num_stages ,
28182845 register_intensive = register_intensive ,
2846+ waves_per_eu = waves_per_eu ,
28192847 )
28202848
28212849
0 commit comments