Commit eb47158

okakarpa and jataylo authored
[AUTOGENERATED] [release/2.8] [SWDEV-539215] - Autotune support for persistent reduction and no_x_dim removal (#2454)
Cherry-pick of #2417. Needed to resolve conflicts.

Co-authored-by: Jack Taylor <[email protected]>
1 parent 2067a0b commit eb47158

File tree

test/inductor/test_combo_kernels.py
test/inductor/test_torchinductor_strided_blocks.py
torch/_inductor/choices.py
torch/_inductor/codegen/triton.py
torch/_inductor/runtime/triton_heuristics.py

5 files changed: +27 -59 lines changed


test/inductor/test_combo_kernels.py

Lines changed: 0 additions & 17 deletions
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
 
         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
 
-    @requires_cuda
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
 
 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 0 additions & 25 deletions
@@ -706,31 +706,6 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)
 
-    def test_2d_reduction_no_x_dim(self):
-        """
-        Tests a 2D reduction without an "x" dimension.
-        """
-        # We need a size to get no x dim.
-        view = self._discontiguous_tensor((2, 346), self.device)
-
-        # Expect 1 block pointer for the input.
-        result, (code,) = run_and_compare(
-            self,
-            torch.prod,
-            view,
-            expected_num_block_pointers=1,
-            expected_num_triton_kernels=1,
-            config_patches=tiled_reduction_config,
-        )
-
-        # Check that there's no X dimension in the signature.
-        (signature_line,) = (
-            line for line in code.splitlines() if line.startswith("def triton")
-        )
-        self.assertNotIn("BLOCK", signature_line)
-
-        # Check for 2 reduction dimensions in the body.
-        self._assert_reduction_ndims(code, 2)
 
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",

torch/_inductor/choices.py

Lines changed: 4 additions & 4 deletions
@@ -215,11 +215,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
         Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
         So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
         Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
-        return (
-            features.get_reduction_hint() == ReductionHint.INNER
-            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
-        )
+        return False
 
     @staticmethod
     def reduction_split_factor(
torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 7 deletions
@@ -1768,13 +1768,10 @@ def should_use_persistent_reduction(self) -> bool:
         )
 
     def want_no_x_dim(self):
-        if (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-        ):
-            if self.fixed_config:
-                return self.fixed_config["XBLOCK"] == 1
-            return V.choices.want_no_x_dim(self.features)
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
         return False
 
     @property

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 19 additions & 6 deletions
@@ -2556,6 +2556,10 @@ def _persistent_reduction_configs(
     rnumel = get_total_reduction_numel(size_hints)
 
     MAX_PERSISTENT_BLOCK_NUMEL = 4096
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
 
     if "y" not in size_hints:
         configs = [
@@ -2585,18 +2589,27 @@
     if "y" in size_hints:
         pass
     # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:

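The net effect of the triton_heuristics.py hunks is easier to see in isolation. The sketch below is a simplified, hypothetical rendering of the new selection flow: select_persistent_reduction_configs, configs, and tiny_config are stand-ins for the candidate lists the real code builds with triton_config_reduction, and max_autotune_enabled stands in for the inductor_meta check added in the first hunk; it is not the upstream function itself.

from enum import Enum, auto


class ReductionHint(Enum):
    # Simplified stand-in for Inductor's ReductionHint enum
    DEFAULT = auto()
    INNER = auto()
    OUTER = auto()
    OUTER_TINY = auto()


def select_persistent_reduction_configs(
    configs: list,
    tiny_config,
    reduction_hint: ReductionHint,
    rnumel: int,
    max_autotune_enabled: bool,
) -> list:
    # When autotuning is off, keep the old behavior: prune the candidates
    # down to a single hint-appropriate config.
    if not max_autotune_enabled:
        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
            configs = configs[:1]
        elif reduction_hint == ReductionHint.OUTER:
            configs = configs[-1:]

    # OUTER_TINY: with autotuning, add the tiny config to the candidate
    # pool; without it, use the tiny config alone (the old behavior).
    if reduction_hint == ReductionHint.OUTER_TINY:
        if max_autotune_enabled:
            if tiny_config not in configs:
                configs = configs + [tiny_config]
        else:
            configs = [tiny_config]
    return configs


# Example: with max-autotune on, an INNER reduction keeps every candidate
# instead of being trimmed to the first one.
candidates = ["cfg_xblock_1", "cfg_xblock_8", "cfg_xblock_32"]
print(select_persistent_reduction_configs(
    candidates, "cfg_tiny", ReductionHint.INNER, rnumel=1024, max_autotune_enabled=True
))

In short, reduction-hint filtering now only applies when autotuning is disabled, which is what lets the autotuner sweep the full persistent-reduction config space that the removed want_no_x_dim heuristic used to short-circuit.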