Skip to content

Commit 768361e

Browse files
PaulZhang12 authored and pytorchmergebot committed
Add less warps config to inner reductions (pytorch#162447)
Add a config with fewer warps to ensure proper vectorization and memory coalescing for inner reductions, preferring more work per thread.
<img width="1717" height="731" alt="Screenshot 2025-09-17 at 10 03 25 AM" src="https://github.com/user-attachments/assets/7b1f4a30-62f2-4bee-bb9c-122501bde63e" />
Pull Request resolved: pytorch#162447
Approved by: https://github.com/v0i0, https://github.com/eellison, https://github.com/shunting314
1 parent 9341ede commit 768361e

File tree

1 file changed

+18
-3
lines changed

1 file changed

+18
-3
lines changed

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 18 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2333,6 +2333,7 @@ def triton_config_reduction(
23332333
num_warps=None,
23342334
register_intensive=False,
23352335
dynamic_scale_rblock=True,
2336+
reduction_hint=None,
23362337
) -> Config:
23372338
"""
23382339
Construct a reduction triton config with some adjustment heuristics
@@ -2360,7 +2361,12 @@ def total_numel() -> int:
23602361
rnumels[prefix] *= 2
23612362

23622363
if num_warps is None:
2363-
num_warps = total_numel() // 128
2364+
if reduction_hint == ReductionHint.INNER:
2365+
# r is contiguous, so ensure that each thread has 8 elements for
2366+
# vectorized loads, assuming bf16/fp16
2367+
num_warps = r // (32 * 8)
2368+
else:
2369+
num_warps = total_numel() // 128
23642370

23652371
max_num_warps = 16 if r <= 8192 else 32
23662372
num_warps = _num_warps(
@@ -2630,6 +2636,7 @@ def make_config(
26302636
num_stages=num_stages,
26312637
register_intensive=register_intensive,
26322638
dynamic_scale_rblock=dynamic_scale_rblock,
2639+
reduction_hint=reduction_hint,
26332640
)
26342641

26352642
def outer_config_opt():
@@ -2681,7 +2688,7 @@ def outer_config_opt():
26812688
)
26822689

26832690
contiguous_config = make_config(
2684-
1,
2691+
1 if rnumel > 2048 else 2, # 1024 or less is persistent
26852692
min(rnumel, MAX_R0_BLOCK),
26862693
register_intensive=register_intensive,
26872694
)
@@ -2911,7 +2918,13 @@ def _persistent_reduction_configs(
29112918

29122919
if "y" not in size_hints:
29132920
configs = [
2914-
triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
2921+
triton_config_reduction(
2922+
size_hints,
2923+
xblock,
2924+
rnumel,
2925+
register_intensive=True,
2926+
reduction_hint=reduction_hint,
2927+
)
29152928
for xblock in (1, 8, 32, 128)
29162929
if xblock == 1
29172930
or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
@@ -2954,6 +2967,7 @@ def _persistent_reduction_configs(
29542967
x_block,
29552968
rnumel,
29562969
register_intensive=True,
2970+
reduction_hint=reduction_hint,
29572971
)
29582972
]
29592973

@@ -2965,6 +2979,7 @@ def _persistent_reduction_configs(
29652979
size_hints,
29662980
2 * (256 // rnumel) if rnumel <= 256 else 1,
29672981
rnumel,
2982+
reduction_hint=reduction_hint,
29682983
)
29692984
]
29702985
for c in configs:

0 commit comments

Comments
 (0)