Skip to content

Commit 93dd572

Browse files
authored
red kernel config heuristics number of args fixed for 2D red kernels (#2768)
ditto
1 parent 3658645 commit 93dd572

File tree

1 file changed

+16
-16
lines changed

1 file changed

+16
-16
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2586,15 +2586,6 @@ def _persistent_reduction_configs(
25862586
)
25872587

25882588
# defer to more autotuning, initially
2589-
if "y" in size_hints:
2590-
pass
2591-
# TODO(jansel): we should be able to improve these heuristics
2592-
if not max_autotune_enabled: # Don't filter if tuning enabled
2593-
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
2594-
configs = configs[:1]
2595-
elif reduction_hint == ReductionHint.OUTER:
2596-
configs = configs[-1:]
2597-
25982589
tiny_configs = [
25992590
triton_config_reduction(
26002591
size_hints,
@@ -2603,13 +2594,22 @@ def _persistent_reduction_configs(
26032594
)
26042595
]
26052596

2606-
if max_autotune_enabled:
2607-
for conf in tiny_configs:
2608-
if conf not in configs:
2609-
configs.append(conf)
2610-
elif reduction_hint == ReductionHint.OUTER_TINY:
2611-
configs = tiny_configs
2612-
2597+
if "y" in size_hints:
2598+
pass
2599+
# TODO(jansel): we should be able to improve these heuristics
2600+
elif not max_autotune_enabled: # Don't filter if tuning enabled
2601+
if reduction_hint == ReductionHint.INNER and rnumel >= 256:
2602+
configs = configs[:1]
2603+
elif reduction_hint == ReductionHint.OUTER:
2604+
configs = configs[-1:]
2605+
elif reduction_hint == ReductionHint.OUTER_TINY:
2606+
configs = tiny_configs
2607+
else:
2608+
if max_autotune_enabled:
2609+
for conf in tiny_configs:
2610+
if conf not in configs:
2611+
configs.append(conf)
2612+
26132613
for c in configs:
26142614
# we don't need Rn_BLOCK for persistent reduction
26152615
for prefix in size_hints:

0 commit comments

Comments
 (0)