
Commit 7de1214

naromero77amd, jataylo, and AmdSampsa
authored
[NO CP][release/2.7][ROCm][inductor] Inductor heuristic upstream backports (#2807)
These are backports based on the following upstream PRs. Cherry-picks were performed where possible.

- pytorch#163908 (persistent reduction autotune)
- pytorch#161280 (reduction)
- pytorch#162053 (foreach)
- pytorch#163197 (pointwise)
- pytorch#166470 (pointwise config for atomic add)

Also included are some additional customer-specific configs which were not upstreamed but are in the release/2.9 backport (#2723).

Did not backport filter functions such as `_maybe_filter_configs_for_tma_restrictions`:
https://github.com/ROCm/pytorch/blob/release/2.9/torch/_inductor/runtime/triton_heuristics.py#L2614

---------

Co-authored-by: Jack Taylor <[email protected]>
Co-authored-by: Jack Taylor <[email protected]>
Co-authored-by: Sampsa Riikonen <[email protected]>
Co-authored-by: AmdSampsa <[email protected]>
1 parent 175d5a6 commit 7de1214

File tree

5 files changed: +448 -59 lines changed

torch/_inductor/codegen/common.py

Lines changed: 1 addition & 0 deletions
@@ -1873,6 +1873,7 @@ def __init__(
         self.compute = IndentedBuffer()
         self.stores = IndentedBuffer()

+        self.atomic_add_found = False
         self.num_load = 0
         self.num_reduction = 0

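The flag is only initialized here; the torch/_inductor/codegen/triton.py hunks below set it when an atomic_add store is emitted and export it through inductor_meta. A heuristic can then special-case kernels that contain atomic adds, roughly as in this minimal sketch (the function name, signature, and XBLOCK cap are illustrative assumptions, not the backported code):

# Hypothetical sketch: how an autotuning heuristic could consume the new
# "atomic_add_found" flag once it reaches inductor_meta (see the triton.py
# hunks below). The function name and the XBLOCK cap are assumptions.
def _filter_configs_for_atomic_add(configs, inductor_meta):
    if not inductor_meta.get("atomic_add_found", False):
        return configs
    # With atomic adds in the kernel, prefer smaller X blocks to limit
    # contention; the real backported policy may differ.
    smaller = [c for c in configs if c.kwargs.get("XBLOCK", 1) <= 1024]
    return smaller or configs
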
torch/_inductor/codegen/triton.py

Lines changed: 16 additions & 4 deletions
@@ -1002,11 +1002,17 @@ def relu(x):

     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"

     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"

     @staticmethod
     def where(a, b, c):

@@ -1202,7 +1208,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"

     @staticmethod
     @maybe_upcast_float32()

@@ -2284,6 +2293,7 @@ def store(
         elif mode is None:
             line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
         elif mode == "atomic_add":
+            self.atomic_add_found = True
             line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')"
         else:
             raise NotImplementedError(f"store mode={mode}")

@@ -3226,8 +3236,9 @@ def codegen_body(self):
             loop_end = (
                 "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
             )
+            num_stages = ", num_stages = 2" if torch.version.hip else ""
             self.body.writeline(
-                f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
             )
             with self.body.indent(offset=level + 1):
                 self.iteration_ranges_codegen_header(tree, self.body)

@@ -3600,6 +3611,7 @@ def add_constexpr_arg(arg_name):
             "mutated_arg_names": mutated_args,
             "optimize_mem": optimize_mem,
             "no_x_dim": self.no_x_dim,
+            "atomic_add_found": self.atomic_add_found,
             "num_load": self.num_load,
             "num_reduction": self.num_reduction,
             **self.inductor_meta_common(),
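
Taken together, the ROCm branches above change what the generated Triton source looks like: NaN-propagating tl.minimum/tl.maximum, tl.rsqrt in place of libdevice.rsqrt, and software-pipelined reduction loops via tl.range(..., num_stages=2). A hand-written approximation follows (not actual Inductor output; the kernel name, arguments, and the clamp/rsqrt math are illustrative assumptions):

# Hand-written approximation of the ROCm codegen choices above; not actual
# Inductor output. Kernel name, arguments, and the math are assumptions.
import triton
import triton.language as tl


@triton.jit
def clamped_rsqrt_sum_kernel(in_ptr, out_ptr, rnumel, RBLOCK: tl.constexpr):
    acc = tl.zeros([RBLOCK], dtype=tl.float32)
    # Pipelined reduction loop, as now emitted on ROCm via tl.range(..., num_stages=2)
    for roffset in tl.range(0, rnumel, RBLOCK, num_stages=2):
        rindex = roffset + tl.arange(0, RBLOCK)
        rmask = rindex < rnumel
        x = tl.load(in_ptr + rindex, mask=rmask, other=1.0)
        # NaN-propagating clamp, matching the tl.minimum/tl.maximum overrides
        x = tl.minimum(tl.maximum(x, 0.0, tl.PropagateNan.ALL), 1.0, tl.PropagateNan.ALL)
        # tl.rsqrt replaces libdevice.rsqrt on ROCm
        acc += tl.where(rmask, tl.rsqrt(x + 1.0), 0.0)
    tl.store(out_ptr, tl.sum(acc, 0))
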

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -1135,7 +1135,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16

     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
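
The higher ROCm default can still be inspected or overridden per run; a minimal sketch, assuming a PyTorch build that includes this patch:

# Minimal sketch: inspect and override the spill threshold at runtime.
# Assumes a PyTorch build containing this patch.
import torch
from torch._inductor import config as inductor_config

expected_default = 32 if torch.version.hip else 16
assert inductor_config.triton.spill_threshold == expected_default

# Loosen the threshold further for a specific experiment if desired.
inductor_config.triton.spill_threshold = 64
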

torch/_inductor/runtime/hints.py

Lines changed: 2 additions & 1 deletion
@@ -7,13 +7,14 @@
 from enum import auto, Enum
 from typing import Optional, Union

+import torch
 from torch.utils._triton import has_triton_package


 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
