
Commit 6c845c6

[SWDEV-539215] - Autotune support for persistent reduction and no_x_dim removal (#2417)
We noticed that persistent reduction kernels can perform extremely poorly (https://ontrack-internal.amd.com/browse/SWDEV-539215). The root cause is that, under certain size restrictions, "no_x_dim" mode is enabled for these kernels, which embeds a static XBLOCK=1 into the generated kernel, so tuning is not optimal. Removing this mode and enabling autotuning achieves 2x performance, showing that new heuristics are needed. We will bring this into 2.7 for a perf uplift. Discussion with upstream on removing no_x_dim is ongoing; they are in agreement provided there is no perf regression. The draft PR pytorch#159048 shows no perf loss on ROCm for any inductor benchmark. The removed tests are no longer relevant.
1 parent f0aebdc commit 6c845c6
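
For illustration only (not part of the commit): a minimal sketch of the affected workload, reusing the shapes from the removed test_persistent_reduction_no_x_dim test and assuming a CUDA/ROCm device is available. With no_x_dim removed, the persistent reduction goes through the normal autotuning path, so Inductor's max-autotune mode can pick XBLOCK instead of having it hard-coded to 1.

import torch

# Sketch: row reductions over (16, 256) and (32, 256) tensors, as in the removed test.
def fn(x, y):
    return x.sum(1), y.sum(1)

inps = (
    torch.rand(16, 256, device="cuda"),
    torch.rand(32, 256, device="cuda"),
)

# "max-autotune" widens the candidate configs generated in
# _persistent_reduction_configs (see the triton_heuristics.py diff below).
compiled = torch.compile(fn, mode="max-autotune")
out_compiled = compiled(*inps)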

File tree

5 files changed (+29, -60 lines)


test/inductor/test_combo_kernels.py

Lines changed: 0 additions & 17 deletions
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
 
         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
 
-    @requires_cuda
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
 
 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 0 additions & 25 deletions
@@ -651,31 +651,6 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)
 
-    def test_2d_reduction_no_x_dim(self):
-        """
-        Tests a 2D reduction without an "x" dimension.
-        """
-        # We need a size to get no x dim.
-        view = self._discontiguous_tensor((2, 346), self.device)
-
-        # Expect 1 block pointer for the input.
-        result, (code,) = run_and_compare(
-            self,
-            torch.prod,
-            view,
-            expected_num_block_pointers=1,
-            expected_num_triton_kernels=1,
-            config_patches=tiled_reduction_config,
-        )
-
-        # Check that there's no X dimension in the signature.
-        (signature_line,) = (
-            line for line in code.splitlines() if line.startswith("def triton")
-        )
-        self.assertNotIn("BLOCK", signature_line)
-
-        # Check for 2 reduction dimensions in the body.
-        self._assert_reduction_ndims(code, 2)
 
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",

torch/_inductor/choices.py

Lines changed: 4 additions & 4 deletions
@@ -109,11 +109,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
         Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
         So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
         Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
         """
-        return (
-            features.get_reduction_hint() == ReductionHint.INNER
-            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
-        )
+        return False
 
     @staticmethod
     def reduction_split_factor(

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 7 deletions
@@ -1713,13 +1713,10 @@ def should_use_persistent_reduction(self) -> bool:
         )
 
     def want_no_x_dim(self):
-        if (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-        ):
-            if self.fixed_config:
-                return self.fixed_config["XBLOCK"] == 1
-            return V.choices.want_no_x_dim(self.features)
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
         return False
 
     @property

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 21 additions & 7 deletions
@@ -2049,25 +2049,39 @@ def _persistent_reduction_configs(
     xnumel = size_hints["x"]
     rnumel = get_total_reduction_numel(size_hints)
 
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
+
     configs = [
         triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
         for xblock in (1, 8, 32, 128)
-        if xblock == 1 or (rnumel * xblock <= 4096 and xblock <= xnumel)
+        if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
     ]
 
     # TODO(jansel): we should be able to improve these heuristics
-    if reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:
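
To make the widened search space concrete, here is a small standalone sketch (the helper name candidate_xblocks is hypothetical, not part of the commit) of the XBLOCK filter in the list comprehension above, evaluated for xnumel=32, rnumel=256:

# Standalone re-implementation of the xblock filter above, for illustration only.
def candidate_xblocks(xnumel, rnumel, max_autotune_enabled):
    return [
        xblock
        for xblock in (1, 8, 32, 128)
        if xblock == 1
        or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
    ]

print(candidate_xblocks(32, 256, max_autotune_enabled=False))  # [1, 8]
print(candidate_xblocks(32, 256, max_autotune_enabled=True))   # [1, 8, 32]

With autotuning enabled, the rnumel * xblock <= 4096 cap no longer prunes candidates up front (only xblock <= xnumel still applies), and the reduction-hint filtering is skipped, so larger XBLOCK values are benchmarked rather than discarded.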

0 commit comments
