
Commit 23c0876

jataylo, k50112113, and jithunnair-amd authored
Triton bump in 7.1_internal_testing (#2479)
Bump to triton pytorch/rocm7.1_internal_testing for gfx950 related improvements - https://github.com/ROCm/triton/tree/pytorch/rocm7.1_internal_testing

---------

Co-authored-by: ShaoChunLee <[email protected]>
Co-authored-by: Jithun Nair <[email protected]>
1 parent c1ee54d commit 23c0876

9 files changed: +64, -75 lines changed
Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-711e2a92522e0a9921ce58ae658571ca55c49b97
+b2c0ea435ece3491b2940af7c08d42974b953e06

test/inductor/test_combo_kernels.py

Lines changed: 0 additions & 17 deletions
@@ -296,23 +296,6 @@ def fn(a0, a1, a2, b0, b1, b2):
 
         self.assertTrue(7 <= torch._inductor.metrics.generated_kernel_count <= 8)
 
-    @requires_cuda
-    def test_persistent_reduction_no_x_dim(self):
-        def fn(x, y):
-            return x.sum(1), y.sum(1)
-
-        inps = (
-            torch.rand(16, 256, device="cuda"),
-            torch.rand(32, 256, device="cuda"),
-        )
-        torch._dynamo.mark_dynamic(inps[0], 0, min=1, max=256)
-        torch._dynamo.mark_dynamic(inps[1], 0, min=1, max=256)
-        out_eager = fn(*inps)
-        out_compiled = torch.compile(fn)(*inps)
-
-        self.assertEqual(out_eager, out_compiled)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
-
 
 @instantiate_parametrized_tests
 class ComboKernelDynamicShapesTests(TestCase):

test/inductor/test_torchinductor_strided_blocks.py

Lines changed: 0 additions & 25 deletions
@@ -746,31 +746,6 @@ def test_2d_reduction_odd_shapes(
         # Check the code for multiple Rn_BLOCK's
         self._assert_reduction_ndims(code, 2)
 
-    def test_2d_reduction_no_x_dim(self):
-        """
-        Tests a 2D reduction without an "x" dimension.
-        """
-        # We need a size to get no x dim.
-        view = self._discontiguous_tensor((2, 346), self.device)
-
-        # Expect 1 block pointer for the input.
-        result, (code,) = self._run_and_compare(
-            torch.prod,
-            view,
-            expected_num_block_pointers=1,
-            expected_num_triton_kernels=1,
-            config_patches=tiled_reduction_config,
-        )
-
-        # Check that there's no X dimension in the signature.
-        (signature_line,) = (
-            line for line in code.splitlines() if line.startswith("def triton")
-        )
-        self.assertNotIn("BLOCK", signature_line)
-
-        # Check for 2 reduction dimensions in the body.
-        self._assert_reduction_ndims(code, 2)
-
     @parametrize(
         "size,expected_num_block_pointers,expected_num_triton_kernels,expect_fallback",
         [

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 3 additions & 3 deletions
@@ -457,15 +457,15 @@ def get_signature_value(idx: int, arg: Any) -> str:
        inspect.signature(backend.get_codegen_implementation).parameters
    )
    if make_ir_sig_params == 2:
-        ttir_module = src.make_ir(options, context)
+        ttir_module = src.make_ir(target, options, context)
    elif make_ir_sig_params == 3:
        codegen_fns = backend.get_codegen_implementation()
-        ttir_module = src.make_ir(options, codegen_fns, context)
+        ttir_module = src.make_ir(target, options, codegen_fns, context)
    else:
        codegen_args = [options] if get_codegen_implementation_sig_params == 1 else []
        codegen_fns = backend.get_codegen_implementation(*codegen_args)
        module_map = backend.get_module_map()
-        ttir_module = src.make_ir(options, codegen_fns, module_map, context)
+        ttir_module = src.make_ir(target, options, codegen_fns, module_map, context)
    if not ttir_module.verify():
        raise RuntimeError("Verification for TTIR module has failed")
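
Note: the branching above dispatches on how many parameters the installed Triton's make_ir accepts. A rough, self-contained sketch of that inspect-based pattern, with a hypothetical stand-in make_ir rather than Triton's real API, purely to illustrate how the extra target argument is threaded through each arity:

import inspect

# Hypothetical stand-in for a version-dependent API; the real src.make_ir
# comes from the installed Triton build and may differ.
def make_ir(target, options, codegen_fns, context):
    return ("ttir", target, options, codegen_fns, context)

def call_make_ir(fn, target, options, codegen_fns, module_map, context):
    # Count the callable's parameters and pass only what this version expects.
    nparams = len(inspect.signature(fn).parameters)
    if nparams == 3:
        return fn(target, options, context)
    if nparams == 4:
        return fn(target, options, codegen_fns, context)
    return fn(target, options, codegen_fns, module_map, context)

print(call_make_ir(make_ir, "gfx950", {"num_warps": 4}, {}, {}, object()))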

torch/_inductor/choices.py

Lines changed: 4 additions & 4 deletions
@@ -202,11 +202,11 @@ def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
         Heuristic to decide if we should drop the X dimension from a persistent reduction kernel.
         So the [XBLOCK, RBLOCK] block becomes a [RBLOCK] block and XBLOCK is forced to be always 1.
         Strangely this is faster than a [1, RBLOCK] block in some cases.
+
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
         """
-        return (
-            features.get_reduction_hint() == ReductionHint.INNER
-            and V.graph.sizevars.statically_known_geq(features.reduction_numel, 256)
-        )
+        return False
 
     @staticmethod
     def reduction_split_factor(

torch/_inductor/codegen/triton.py

Lines changed: 5 additions & 8 deletions
@@ -1288,7 +1288,7 @@ def tan(x):
     @staticmethod
     @maybe_upcast_float32()
     def tanh(x):
-        return f"libdevice.tanh({x})"
+        return f"libdevice.fast_tanhf({x})"
 
     @staticmethod
     @maybe_upcast_float32()

@@ -1999,13 +1999,10 @@ def should_use_persistent_reduction(self) -> bool:
         )
 
     def want_no_x_dim(self):
-        if (
-            self.persistent_reduction
-            and len(self.numels) == self.num_reduction_dims + 1
-        ):
-            if self.fixed_config:
-                return self.fixed_config["XBLOCK"] == 1
-            return V.choices.want_no_x_dim(self.features)
+        """
+        ROCm branch change: Remove want_no_x_dim for persistent reduction.
+        Inductor benchmarks show no perf advantage and simplifies autotune flow.
+        """
         return False
 
     @property
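
Note: for reference, the retired heuristic and its ROCm replacement can be summarized as standalone predicates. This is a simplified sketch: ReductionHint is a stand-in enum, and the plain >= 256 comparison approximates the original statically_known_geq check:

from enum import Enum

class ReductionHint(Enum):  # stand-in for torch._inductor's ReductionHint
    INNER = 0
    OUTER = 1
    OUTER_TINY = 2

def old_want_no_x_dim(reduction_hint, reduction_numel):
    # Upstream heuristic removed by this patch: drop the X dimension only for
    # inner reductions with at least 256 reduction elements.
    return reduction_hint == ReductionHint.INNER and reduction_numel >= 256

def new_want_no_x_dim(reduction_hint, reduction_numel):
    # ROCm branch: never drop the X dimension; keep [XBLOCK, RBLOCK] tiles.
    return False

print(old_want_no_x_dim(ReductionHint.INNER, 512))  # True
print(new_want_no_x_dim(ReductionHint.INNER, 512))  # False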

torch/_inductor/codegen/triton_combo_kernel.py

Lines changed: 1 addition & 1 deletion
@@ -614,7 +614,7 @@ def jit_line(
         if heuristics == "foreach":
             heuristics_line = f"""
             @triton_heuristics.foreach(
-                num_warps={self.num_warps},
+                filename=__file__,
                 triton_meta={triton_meta!r},
                 inductor_meta={inductor_meta!r},
             )

torch/_inductor/runtime/coordinate_descent_tuner.py

Lines changed: 10 additions & 3 deletions
@@ -3,6 +3,7 @@
 import itertools
 import logging
 from typing import Callable, Optional, TYPE_CHECKING
+from functools import lru_cache
 
 from .hints import TRITON_MAX_BLOCK
 from .runtime_utils import red_text, triton_config_to_hashable

@@ -60,10 +61,16 @@ def get_config_max(self, prefix: str) -> int:
         size_hint = self.size_hints.get(prefix) if self.size_hints is not None else None
         return min(max_block, size_hint) if size_hint is not None else max_block
 
+    @lru_cache(maxsize=1)
     def get_warpsmax(self):
-        # Currently, CUDA has a maximum of 1024 threads, so 32 is the max
-        # number of warps.
-        return 1024 // 32
+        # CUDA/ROCm has a maximum of 1024 threads per block
+        from torch.cuda import current_device, get_device_properties, is_available
+
+        warp_size = (
+            get_device_properties(current_device()).warp_size if is_available() else 32
+        )
+
+        return 1024 // warp_size
 
     def cache_benchmark_result(self, config, timing):
         self.cached_benchmark_results[triton_config_to_hashable(config)] = timing
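
Note: the patched get_warpsmax derives the warp budget from the device's reported warp size instead of hard-coding 32. A standalone check of the same query, assuming a PyTorch build whose device properties expose warp_size, with the patch's fallback of 32 when no GPU is present:

import torch

# Query the hardware warp/wavefront size the way the patch does; fall back to
# 32 when no GPU is available (typically 32 on NVIDIA, 64 on AMD CDNA GPUs).
if torch.cuda.is_available():
    warp_size = torch.cuda.get_device_properties(torch.cuda.current_device()).warp_size
else:
    warp_size = 32

print("max warps in a 1024-thread block:", 1024 // warp_size)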

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 40 additions & 13 deletions
@@ -2689,21 +2689,30 @@ def _persistent_reduction_configs(
     xnumel = size_hints["x"]
     rnumel = get_total_reduction_numel(size_hints)
 
-    MAX_PERSISTENT_BLOCK_NUMEL = 4096
+    max_autotune_enabled = not disable_pointwise_autotuning(inductor_meta) or (
+        inductor_meta.get("max_autotune")
+        or inductor_meta.get("max_autotune_pointwise")
+    )
 
+    configs = [
+        triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
+        for xblock in (1, 8, 32, 128)
+        if xblock == 1 or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
+    ]
+
     if "y" not in size_hints:
         configs = [
             triton_config_reduction(size_hints, xblock, rnumel, register_intensive=True)
             for xblock in (1, 8, 32, 128)
             if xblock == 1
-            or (rnumel * xblock <= MAX_PERSISTENT_BLOCK_NUMEL and xblock <= xnumel)
+            or (rnumel * xblock <= 4096 and xblock <= xnumel)
         ]
     else:
         configs = []
         assert "tiling_scores" in inductor_meta
         x_y_scores = {dim: inductor_meta["tiling_scores"][dim] for dim in ("x", "y")}
         for target_block_size in (1, 8, 32, 64, 128):
-            if target_block_size * rnumel > MAX_PERSISTENT_BLOCK_NUMEL:
+            if target_block_size * rnumel > 4096:
                 continue
 
             block_sizes = match_target_block_product(

@@ -2718,19 +2727,28 @@ def _persistent_reduction_configs(
     # defer to more autotuning, initially
     if "y" in size_hints:
         pass
-    # TODO(jansel): we should be able to improve these heuristics
-    elif reduction_hint == ReductionHint.INNER and rnumel >= 256:
-        configs = configs[:1]
-    elif reduction_hint == ReductionHint.OUTER:
-        configs = configs[-1:]
-    elif reduction_hint == ReductionHint.OUTER_TINY:
-        configs = [
+
+    if not max_autotune_enabled:  # Don't filter if tuning enabled
+        if reduction_hint == ReductionHint.INNER and rnumel >= 256:
+            configs = configs[:1]
+        elif reduction_hint == ReductionHint.OUTER:
+            configs = configs[-1:]
+
+    if reduction_hint == ReductionHint.OUTER_TINY:
+        tiny_configs = [
             triton_config_reduction(
                 size_hints,
                 2 * (256 // rnumel) if rnumel <= 256 else 1,
                 rnumel,
             )
         ]
+        if max_autotune_enabled:
+            for tconfig in tiny_configs:
+                if tconfig not in configs:
+                    configs.append(tconfig)
+        else:
+            configs = tiny_configs
+
     for c in configs:
         # we don't need Rn_BLOCK for persistent reduction
         for prefix in size_hints:
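
Note: to make the new filter concrete, the snippet below restates the XBLOCK candidate selection as a standalone function. The (1, 8, 32, 128) candidates and the 4096 cap come from the hunk above; the sample shape is illustrative:

# Standalone restatement of the XBLOCK candidate filter, for one sample
# persistent-reduction shape. Pure arithmetic; no Triton required.
def candidate_xblocks(xnumel, rnumel, max_autotune_enabled):
    return [
        xblock
        for xblock in (1, 8, 32, 128)
        if xblock == 1
        or (xblock <= xnumel and (max_autotune_enabled or rnumel * xblock <= 4096))
    ]

# With autotuning off, only blocks keeping XBLOCK * rnumel <= 4096 survive;
# with max-autotune on, any XBLOCK that fits xnumel is kept for benchmarking.
print(candidate_xblocks(64, 512, max_autotune_enabled=False))  # [1, 8]
print(candidate_xblocks(64, 512, max_autotune_enabled=True))   # [1, 8, 32]
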
@@ -2922,20 +2940,29 @@ def user_autotune(
     )
 
 
-def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
+def foreach(triton_meta, filename=None, inductor_meta=None):
     """
     Compile a triton foreach kernel
     """
+    configs = []
+    if disable_pointwise_autotuning(inductor_meta) and not (
+        inductor_meta.get("max_autotune") or
+        inductor_meta.get("max_autotune_pointwise")
+    ):
+        configs.append(triton.Config({}, num_stages=1, num_warps=8))
+    else:
+        for warps in [1, 2, 4, 8]:
+            configs.append(triton.Config({}, num_stages=1, num_warps=warps))
+
     return cached_autotune(
         None,
-        [triton.Config({}, num_stages=1, num_warps=num_warps)],
+        configs,
         triton_meta=triton_meta,
         inductor_meta=inductor_meta,
         heuristic_type=HeuristicType.TEMPLATE,
         filename=filename,
     )
 
-
 @dataclasses.dataclass
 class GridExpr:
     """Generate code for grid size expressions in launcher"""
