
Commit c58ceb1

okakarpa authored and jataylo committed
[AUTOGENERATED] [release/2.8] [SWDEV-539215] - Autotune support for persistent reduction and no_x_dim removal (#2454)
Cherry-pick of #2417. Need to resolve conflicts.

Co-authored-by: Jack Taylor <[email protected]>
(cherry picked from commit eb47158)
1 parent 5e67be1 commit c58ceb1
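
The change only pays off when Inductor's autotuning is actually enabled. A minimal usage sketch, assuming a stock PyTorch 2.x build with a Triton-capable GPU (the function and tensor shapes below are illustrative, not taken from this commit):

    import torch

    def fn(x):
        # A row-wise reduction is a typical candidate for a persistent reduction kernel.
        return x.sum(dim=1)

    x = torch.rand(32, 256, device="cuda")

    # mode="max-autotune" enables Inductor autotuning; with this commit the
    # persistent-reduction configs are benchmarked instead of being pre-filtered
    # down to a single config by the reduction hint.
    compiled = torch.compile(fn, mode="max-autotune")
    out = compiled(x)

Which configs end up being benchmarked is decided by the heuristics changed in the diffs below.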

File tree

5 files changed: +37 -29 lines


test/inductor/test_combo_kernels.py

Lines changed: 0 additions & 17 deletions

@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
 
         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
 
-    @requires_cuda_and_triton
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
 
 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 1 addition & 0 deletions

@@ -816,6 +816,7 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)
 
+
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
         [

torch/_inductor/choices.py

Lines changed: 12 additions & 0 deletions

@@ -232,6 +232,18 @@ def should_use_persistent_reduction(
             features.reduction_numel, threshold
         )  # type: ignore[arg-types]
 
+    @staticmethod
+    def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
+        """
+        Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
+        So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
+        Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
+        return False
+
     @staticmethod
     def reduction_split_factor(
         device: torch.device,
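
To make the role of this hook concrete, here is a small self-contained sketch of how a want_no_x_dim-style heuristic feeds a kernel-planning decision. The plan_persistent_reduction helper and PlannedKernel container are hypothetical stand-ins for illustration only; with the method above hard-wired to False, the [XBLOCK, RBLOCK] layout is always kept and the block sizes are left to the autotuner.

    from dataclasses import dataclass

    @dataclass
    class PlannedKernel:
        # Hypothetical container: the block dimensions the generated kernel is tiled over.
        block_dims: tuple

    def want_no_x_dim(features) -> bool:
        # Mirrors the new default above: never collapse [XBLOCK, RBLOCK] down to [RBLOCK].
        return False

    def plan_persistent_reduction(features) -> PlannedKernel:
        # Hypothetical call site: the heuristic picks the block layout up front.
        if want_no_x_dim(features):
            return PlannedKernel(block_dims=("RBLOCK",))
        return PlannedKernel(block_dims=("XBLOCK", "RBLOCK"))

    print(plan_persistent_reduction(features=None).block_dims)  # ('XBLOCK', 'RBLOCK')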

torch/_inductor/codegen/triton.py

Lines changed: 5 additions & 6 deletions

@@ -2030,12 +2030,11 @@ def should_use_persistent_reduction(self) -> bool:
         )
 
     def want_no_x_dim(self):
-        return (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-            and self.fixed_config
-            and self.fixed_config["XBLOCK"] == 1
-        )
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
+        return False
 
     @property
     def assert_function(self) -> str:

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 19 additions & 6 deletions

@@ -2870,6 +2870,10 @@ def _persistent_reduction_configs(
     rnumel = get_total_reduction_numel(size_hints)
 
     MAX_PERSISTENT_BLOCK_NUMEL = 4096
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
 
     if "y" not in size_hints:
         configs = [
@@ -2899,18 +2903,27 @@
     if "y" in size_hints:
         pass
     # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:
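
The effect of this hunk can be summarized with a standalone sketch (the tuple-based configs and the select_persistent_reduction_configs name below are simplified stand-ins, not the real triton_config_reduction objects): with autotuning off, the reduction hint prunes the candidate list exactly as before; with autotuning on, the list is left intact and the OUTER_TINY config is appended as one more candidate to benchmark instead of replacing the others.

    from enum import Enum

    class ReductionHint(Enum):
        DEFAULT = "default"
        INNER = "inner"
        OUTER = "outer"
        OUTER_TINY = "outer_tiny"

    def select_persistent_reduction_configs(configs, reduction_hint, rnumel, max_autotune_enabled):
        # Simplified stand-in for the filtering logic added to _persistent_reduction_configs.
        if not max_autotune_enabled:  # don't filter when tuning is enabled
            if reduction_hint is ReductionHint.INNER and rnumel >= 256:
                configs = configs[:1]
            elif reduction_hint is ReductionHint.OUTER:
                configs = configs[-1:]

        if reduction_hint is ReductionHint.OUTER_TINY:
            tiny_configs = [("tiny", 2 * (256 // rnumel) if rnumel <= 256 else 1, rnumel)]
            if max_autotune_enabled:
                # Keep the existing candidates and add the tiny config for benchmarking.
                for cfg in tiny_configs:
                    if cfg not in configs:
                        configs.append(cfg)
            else:
                # Without autotuning, the tiny config is the only candidate, as before.
                configs = tiny_configs
        return configs

    base = [("xblock64", 64, 128), ("xblock8", 8, 512)]
    print(select_persistent_reduction_configs(list(base), ReductionHint.OUTER_TINY, 128, True))
    print(select_persistent_reduction_configs(list(base), ReductionHint.OUTER_TINY, 128, False))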
