diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 2ea6a2d467a67..a79d6cd41b7cc 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2595,20 +2595,20 @@ def _persistent_reduction_configs( elif reduction_hint == ReductionHint.OUTER: configs = configs[-1:] - if reduction_hint == ReductionHint.OUTER_TINY: - tiny_configs = [ - triton_config_reduction( - size_hints, - 2 * (256 // rnumel) if rnumel <= 256 else 1, - rnumel, - ) - ] - if max_autotune_enabled: - for tconfig in tiny_configs: - if tconfig not in configs: - configs.append(tconfig) - else: - configs = tiny_configs + tiny_configs = [ + triton_config_reduction( + size_hints, + 2 * (256 // rnumel) if rnumel <= 256 else 1, + rnumel, + ) + ] + + if max_autotune_enabled: + for conf in tiny_configs: + if conf not in configs: + configs.append(conf) + elif reduction_hint == ReductionHint.OUTER_TINY: + configs = tiny_configs for c in configs: # we don't need Rn_BLOCK for persistent reduction