Skip to content

Commit 4e630f0

Browse files
Revert "[Inductor] Update Outer Reduction Heuristic (pytorch#159093)"
This reverts commit ca9fe01. Reverted pytorch#159093 on behalf of https://github.com/PaulZhang12 due to: addressing internal implications, then relanding ([comment](pytorch#159093 (comment)))
1 parent cddcaa1 commit 4e630f0

File tree

1 file changed

+3
-62
lines changed

1 file changed

+3
-62
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 3 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,7 +2478,7 @@ def pointwise(
24782478

24792479

24802480
def _reduction_configs(
2481-
*, size_hints: dict[str, int], inductor_meta: dict[str, Any], num_dynamic=0
2481+
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
24822482
) -> list[Config]:
24832483
reduction_hint = inductor_meta.get("reduction_hint", None)
24842484

@@ -2531,68 +2531,17 @@ def make_config(x, r, num_warps=None, num_stages=1, register_intensive=False):
25312531
register_intensive=register_intensive,
25322532
)
25332533

2534-
def outer_config_opt():
2535-
# Default to 64 for vectorized loads
2536-
max_x_block, x_block = 256, 64
2537-
load_factor = inductor_meta.get("num_load", 0)
2538-
x = size_hints["x"]
2539-
num_warps = None
2540-
2541-
# Try to use all SMs with small x
2542-
if x <= 1024:
2543-
x_block = max(min(x // 128, 8), 2)
2544-
outer_r_block = min(rnumel, 64)
2545-
# Lower bound x = 1024, 1024 // 16 = 128 around # of SMs
2546-
elif x // 4096 <= 8:
2547-
x_block = 16
2548-
outer_r_block = 512 // x_block
2549-
elif num_dynamic > 1:
2550-
# Lots of compute with multiple dynamic shape per loop iteration
2551-
# Larger RBLOCK minimizes loop iteration
2552-
outer_r_block = max(min((rnumel // 64), 64), 8)
2553-
elif num_dynamic == 1:
2554-
# Dynamic shapes introduce a lot register pressure for indexing
2555-
outer_r_block = (
2556-
1
2557-
if load_factor >= 3
2558-
else min(next_power_of_2(max(rnumel, 128) // 128), 8)
2559-
)
2560-
else:
2561-
x_block = max(min(max_x_block, next_power_of_2(x // 4096)), x_block)
2562-
if load_factor < 4 or rnumel <= 128:
2563-
outer_r_block = 512 // x_block
2564-
else:
2565-
# Heavier reductions contain a lot more overhead per loop iteration
2566-
# We minimize the overhead by enlarging r block
2567-
if rnumel >= 2048:
2568-
outer_r_block = 64
2569-
else:
2570-
outer_r_block = 32
2571-
x_block = min(x_block, 32)
2572-
num_warps = 4
2573-
2574-
# Set register intensive to true by default as we try to maximize tiles with heuristic
2575-
return make_config(
2576-
x_block,
2577-
outer_r_block,
2578-
num_warps=num_warps,
2579-
register_intensive=register_intensive,
2580-
)
2581-
25822534
contiguous_config = make_config(
25832535
1,
25842536
min(rnumel, MAX_R0_BLOCK),
25852537
register_intensive=register_intensive,
25862538
)
2539+
outer_config = make_config(64, 8, register_intensive=register_intensive)
25872540
tiny_config = make_config(
25882541
2 * (256 // rnumel) if rnumel <= 256 else 1,
25892542
min(rnumel, MAX_R0_BLOCK),
25902543
register_intensive=register_intensive,
25912544
)
2592-
2593-
outer_config = make_config(64, 8, register_intensive=register_intensive)
2594-
if not torch.version.hip:
2595-
outer_config = outer_config_opt()
25962545
# For 3d tiling, default to more autotuning initially
25972546
if "y" in size_hints:
25982547
pass
@@ -2712,15 +2661,7 @@ def reduction(
27122661

27132662
assert triton_meta is not None
27142663

2715-
num_dynamic = 0
2716-
for k in triton_meta["signature"].keys():
2717-
if "ks" in k:
2718-
num_dynamic += 1
2719-
2720-
configs = _reduction_configs(
2721-
size_hints=size_hints, inductor_meta=inductor_meta, num_dynamic=num_dynamic
2722-
)
2723-
2664+
configs = _reduction_configs(size_hints=size_hints, inductor_meta=inductor_meta)
27242665
configs = _maybe_filter_configs_for_tma_restrictions(inductor_meta, configs)
27252666
return cached_autotune(
27262667
size_hints,

0 commit comments

Comments (0)