
Commit bf92b1a

jataylo, naromero77amd, and AmdSampsa authored and committed
[Release 2.9] Inductor perf cherry picks (#2723)
These changes are currently in the process of being upstreamed. Bring them into release 2.9 for customer model perf improvements.

---------

Co-authored-by: Nichols A. Romero <[email protected]>
Co-authored-by: Sampsa Riikonen <[email protected]>
Co-authored-by: Nichols A. Romero <[email protected]>
Co-authored-by: AmdSampsa <[email protected]>
1 parent 839c2fd commit bf92b1a

File tree

5 files changed
+201 -63 lines changed


test/inductor/test_async_compile.py

Lines changed: 9 additions & 2 deletions
@@ -74,7 +74,14 @@ def f(a, b):
             return (a @ b).to(torch.float32).sum(dim=1)
 
         # Fake name to make sure the lookup table is name agnostic
-        func_def = """
+        # When codegen/triton.py is changed, func_def must be updated
+        loop_header = (
+            "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
+            if torch.version.hip
+            else "for r0_offset in range(0, r0_numel, R0_BLOCK):"
+        )
+
+        func_def = f"""
 def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
     xnumel = 1024
     r0_numel = 11776
@@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons
     rbase = r0_base
     x0 = xindex
     _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
-    for r0_offset in range(0, r0_numel, R0_BLOCK):
+    {loop_header}
     r0_index = r0_offset + r0_base
     r0_mask = r0_index < r0_numel
     roffset = r0_offset
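For context, a minimal sketch (not part of the diff) of the torch.version.hip gate this test update relies on: torch.version.hip is None on CUDA/CPU builds and a version string on ROCm builds, so the test only expects the ROCm-specific loop header when running on ROCm. The variable names below are illustrative.

import torch

# torch.version.hip is None on non-ROCm builds, so this boolean selects the
# loop header that codegen/triton.py is expected to emit on each platform.
on_rocm = bool(torch.version.hip)
expected_loop_header = (
    "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
    if on_rocm
    else "for r0_offset in range(0, r0_numel, R0_BLOCK):"
)
print(f"ROCm build: {on_rocm}; expected loop header: {expected_loop_header}")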

torch/_inductor/codegen/triton.py

Lines changed: 14 additions & 4 deletions
@@ -1101,11 +1101,17 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
@@ -1291,7 +1297,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -3788,8 +3797,9 @@ def codegen_body(self):
                loop_end = (
                    "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
                )
+               num_stages = ", num_stages = 2" if torch.version.hip else ""
                self.body.writeline(
-                   f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                   f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
                )
                with self.body.indent(offset=level + 1):
                    self.iteration_ranges_codegen_header(tree, self.body)
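For context: on ROCm these codegen changes route minimum/maximum to tl.minimum/tl.maximum with tl.PropagateNan.ALL (keeping NaN propagation without the triton_helpers wrappers), rsqrt to tl.rsqrt instead of libdevice.rsqrt, and emit the reduction loop as tl.range(..., num_stages=2) so Triton can software-pipeline it. Below is a rough hand-written sketch of a reduction loop in that emitted style, assuming a Triton version that supports tl.range(..., num_stages=...); the kernel and pointer names are made up for illustration and are not generated by Inductor.

import triton
import triton.language as tl


@triton.jit
def rowsum_kernel(in_ptr, out_ptr, xnumel, r0_numel, XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr):
    # One program handles XBLOCK rows; the reduction dimension is tiled by R0_BLOCK.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    # tl.range with num_stages=2 asks the compiler to software-pipeline the loop,
    # overlapping global loads for the next tile with compute on the current one.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages=2):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        vals = tl.load(in_ptr + r0_numel * xindex + r0_index, xmask & r0_mask, other=0.0)
        acc += vals
    tl.store(out_ptr + xindex, tl.sum(acc, axis=1)[:, None], xmask)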

torch/_inductor/config.py

Lines changed: 1 addition & 1 deletion
@@ -1395,7 +1395,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
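For context, spill_threshold is read through torch._inductor.config during Triton autotuning; roughly speaking, candidate configs whose compiled kernels spill more registers than the threshold are deprioritized, so the higher ROCm default keeps more candidates in play. A small sketch of inspecting and overriding it; the override value 48 is arbitrary and only illustrative.

import torch
import torch._inductor.config as inductor_config

# After this change the default is 32 on ROCm builds and 16 elsewhere.
print(inductor_config.triton.spill_threshold)

# Inductor config values can be overridden at runtime before compilation;
# 48 here is just an example value, not a recommendation.
inductor_config.triton.spill_threshold = 48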

torch/_inductor/runtime/hints.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
