diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 951977af602e5..6abe93672c465 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2507,6 +2507,15 @@ def pointwise( ), *hinted_configs, ] + # Additional pointwise config appended for ROCm builds + if torch.version.hip: + configs.append(triton_config_with_settings( + size_hints, + 2048, + num_warps=8, + num_stages=2, + waves_per_eu=1 + )) # 20% improvement if len(size_hints) == 2: if ( disable_pointwise_autotuning(inductor_meta) # or tile_hint == TileHint.SQUARE