
Commit 13d0455

Expand persistent reduction tuning space
(cherry picked from commit 7a77bc4)
1 parent 3dd863d · commit 13d0455

File tree

1 file changed: +18 -2 lines changed


torch/_inductor/runtime/triton_heuristics.py

Lines changed: 18 additions & 2 deletions
@@ -2739,15 +2739,15 @@ def _persistent_reduction_configs(
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
-            for xblock in (1, 8, 32, 128)
+            for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512)
             if xblock == 1
             or (xblock <= xnumel and rnumel * xblock <= 4096)
         ]
     else:
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
-        for target_block_size in (1, 8, 32, 64, 128):
+        for target_block_size in (1, 4, 8, 16, 32, 64, 128, 256, 512):
             if target_block_size * rnumel > 4096:
                 continue

@@ -2782,6 +2782,22 @@ def _persistent_reduction_configs(
         for conf in tiny_configs:
             if conf not in configs:
                 configs.append(conf)
+
+        # Expand configs to try additional warps
+        expanded_configs = []
+        for conf in configs:
+            num_warps = conf.num_warps
+            max_warps = 8 if torch.version.hip else 16
+            small_conf = copy.deepcopy(conf)
+            large_conf = copy.deepcopy(conf)
+            small_conf.num_warps = max(small_conf.num_warps // 2, 1)
+            large_conf.num_warps = min(large_conf.num_warps * 2, max_warps)
+            expanded_configs.append(conf)
+            expanded_configs.append(small_conf)
+            expanded_configs.append(large_conf)
+
+        configs = expanded_configs
+
     elif reduction_hint == ReductionHint.OUTER_TINY:
         configs = tiny_configs
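
To see what the wider XBLOCK candidate list in the first hunk buys, here is an illustrative sketch (not part of the commit) that replays the list-comprehension filter for a sample problem size. surviving_xblocks is a hypothetical helper; the real heuristic wraps each surviving value in triton_config_reduction(...).

# Illustrative only: reproduces the XBLOCK candidate filter from the first
# hunk above, showing which block sizes survive for a given problem size.
# xnumel / rnumel mirror the variables in _persistent_reduction_configs.
def surviving_xblocks(xnumel: int, rnumel: int) -> list[int]:
    return [
        xblock
        for xblock in (1, 4, 8, 16, 32, 64, 128, 256, 512)
        if xblock == 1 or (xblock <= xnumel and rnumel * xblock <= 4096)
    ]

# e.g. a persistent reduction over 1024 rows of 128 elements each:
# 128 * 32 = 4096 is still within the cap, 128 * 64 = 8192 is not.
print(surviving_xblocks(1024, 128))  # [1, 4, 8, 16, 32]

With the old candidate list (1, 8, 32, 128), the same shape yields only [1, 8, 32].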

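The second hunk then triples the candidate count by pairing every config with a half-warp and a double-warp variant. Below is a minimal standalone sketch of that expansion, assuming a hypothetical FakeConfig stand-in for the Triton Config objects the heuristic actually builds; the real code deep-copies those and caps num_warps at 8 on ROCm (where torch.version.hip is set) and at 16 otherwise.

# Minimal sketch of the warp-expansion step; FakeConfig is a stand-in for
# the autotuner's Config objects and is not part of the commit.
import copy
from dataclasses import dataclass


@dataclass
class FakeConfig:
    xblock: int
    num_warps: int


def expand_warps(configs, max_warps=16):
    expanded = []
    for conf in configs:
        small = copy.deepcopy(conf)
        large = copy.deepcopy(conf)
        small.num_warps = max(conf.num_warps // 2, 1)         # half the warps, floored at 1
        large.num_warps = min(conf.num_warps * 2, max_warps)  # double the warps, capped
        expanded += [conf, small, large]                       # keep the original config too
    return expanded


for c in expand_warps([FakeConfig(xblock=32, num_warps=4)]):
    print(c)
# FakeConfig(xblock=32, num_warps=4)
# FakeConfig(xblock=32, num_warps=2)
# FakeConfig(xblock=32, num_warps=8)

As in the diff, a config whose num_warps is already 1 (or already at max_warps) produces a variant identical to itself; weeding out such duplicates is left to the existing autotuning machinery.
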
0 commit comments
