diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 440aecf5163ee..4957ba77a7648 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2417,6 +2417,15 @@ def pointwise( ), *hinted_configs, ] + # Additional reduction configs appended for ROCm builds + if torch.version.hip: + configs.append(triton_config_with_settings( + size_hints, + 2048, + num_warps=8, + num_stages=2, + waves_per_eu=1 + )) # 20% improvement if len(size_hints) == 2: if ( disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE