Merged

23 commits
d3d77f5
Naive foreach autotune support
jataylo Sep 3, 2025
528cf02
Update triton_heuristics.py
jataylo Sep 4, 2025
11e1dfc
Update triton_heuristics.py
jataylo Sep 9, 2025
0cf1c89
Linting
jataylo Sep 9, 2025
547e75f
Add max_autotune_enabled boolean.
naromero77amd Sep 25, 2025
733c36a
Conditionalize xblock values for HIP.
naromero77amd Sep 25, 2025
dee2fdf
Create tiny_config.
naromero77amd Sep 25, 2025
9b2cf21
Do not filter configs when tuning.
naromero77amd Sep 25, 2025
691a328
Add tiny_config when autotune enabled.
naromero77amd Sep 25, 2025
1883479
Fix persistent reduction conflicts
jataylo Oct 16, 2025
b9e0182
pointwise autotuning returnz (#2636)
AmdSampsa Sep 12, 2025
04aa3e4
a few more configs for 2d grids (#2649)
AmdSampsa Sep 17, 2025
af5f678
[ROCm][inductor] Additional pointwise tunings (#2642)
naromero77amd Sep 17, 2025
8bd33f9
pointwise config with MAX_BLOCK.
naromero77amd Sep 17, 2025
8f60456
Update triton_config method.
naromero77amd Sep 17, 2025
f6aaaf8
Increase TRITON_MAX_BLOCK X value.
naromero77amd Sep 18, 2025
c36d85f
even more configs
AmdSampsa Sep 19, 2025
83e453f
More ROCm conditionalisation
jataylo Sep 23, 2025
0de435f
Lint.
naromero77amd Sep 25, 2025
189481e
Reduction heuristics improvements for ROCm
naromero77amd Aug 22, 2025
eea659c
Add PropagateNan argument to minimum and maximum function.
naromero77amd Oct 2, 2025
4433700
Update autotune_lookup_tunable UT.
naromero77amd Oct 8, 2025
6534df0
New WRT configs for autotuning (#2708)
AmdSampsa Oct 14, 2025
11 changes: 9 additions & 2 deletions test/inductor/test_async_compile.py
@@ -74,7 +74,14 @@ def f(a, b):
             return (a @ b).to(torch.float32).sum(dim=1)

         # Fake name to make sure the lookup table is name agnostic
-        func_def = """
+        # When codegen/triton.py is changed, func_def must be updated
+        loop_header = (
+            "for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):"
+            if torch.version.hip
+            else "for r0_offset in range(0, r0_numel, R0_BLOCK):"
+        )
+
+        func_def = f"""
 def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
     xnumel = 1024
     r0_numel = 11776
@@ -87,7 +94,7 @@ def triton_fused_fake_name(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.cons
     rbase = r0_base
     x0 = xindex
     _tmp3 = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
-    for r0_offset in range(0, r0_numel, R0_BLOCK):
+    {loop_header}
         r0_index = r0_offset + r0_base
         r0_mask = r0_index < r0_numel
         roffset = r0_offset
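The updated expectation mirrors the codegen change in torch/_inductor/codegen/triton.py further down: on ROCm the inner reduction loop is now emitted as tl.range(..., num_stages = 2) so it can be software-pipelined, and the lookup-table test therefore builds its expected kernel source conditionally on torch.version.hip. For reference, here is a minimal hand-written Triton reduction in the same shape; it is an illustrative sketch only, and the kernel name, pointer layout, and block sizes are hypothetical rather than taken from the PR.

import triton
import triton.language as tl


@triton.jit
def sum_rows_kernel(in_ptr, out_ptr, xnumel, r0_numel, XBLOCK: tl.constexpr, R0_BLOCK: tl.constexpr):
    # One program reduces XBLOCK rows of a row-major [xnumel, r0_numel] tensor.
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
    xmask = xindex < xnumel
    r0_base = tl.arange(0, R0_BLOCK)[None, :]
    acc = tl.full([XBLOCK, R0_BLOCK], 0, tl.float32)
    # On ROCm, inductor now emits this loop via tl.range with num_stages=2,
    # which lets Triton prefetch the loads of the next iteration.
    for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages=2):
        r0_index = r0_offset + r0_base
        r0_mask = r0_index < r0_numel
        vals = tl.load(in_ptr + (r0_index + r0_numel * xindex), r0_mask & xmask, other=0.0)
        acc = acc + vals
    tl.store(out_ptr + xindex, tl.sum(acc, 1)[:, None], xmask)

A launch such as sum_rows_kernel[(triton.cdiv(n_rows, 8),)](inp, out, n_rows, n_cols, XBLOCK=8, R0_BLOCK=256) drives the loop; num_stages only changes scheduling, not the computed result.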
18 changes: 14 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -1101,11 +1101,17 @@ def relu(x):

     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"

     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"

     @staticmethod
     def where(a, b, c):
@@ -1291,7 +1297,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"

     @staticmethod
     @maybe_upcast_float32()
@@ -3788,8 +3797,9 @@ def codegen_body(self):
                 loop_end = (
                     "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
                 )
+                num_stages = ", num_stages = 2" if torch.version.hip else ""
                 self.body.writeline(
-                    f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                    f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
                 )
                 with self.body.indent(offset=level + 1):
                     self.iteration_ranges_codegen_header(tree, self.body)
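Two behavioral notes on the ROCm branches above. rsqrt now lowers to tl.rsqrt instead of libdevice.rsqrt, and minimum/maximum lower to tl.minimum/tl.maximum with an explicit tl.PropagateNan.ALL instead of the triton_helpers wrappers. The explicit flag is what keeps the NaN-propagating semantics that triton_helpers.minimum/maximum (and torch.minimum/torch.maximum) provide, since tl.minimum's default NaN mode does not ask for propagation. The sketch below is a host-side Python reference of the two minimum semantics; it is an illustration only, not code from either library.

import math


def minimum_ignore_nan(a: float, b: float) -> float:
    # fmin-like behavior: when exactly one operand is NaN, the other operand wins.
    if math.isnan(a):
        return b
    if math.isnan(b):
        return a
    return min(a, b)


def minimum_propagate_nan(a: float, b: float) -> float:
    # NaN-propagating behavior (what PropagateNan.ALL requests): any NaN input
    # makes the result NaN.
    if math.isnan(a) or math.isnan(b):
        return float("nan")
    return min(a, b)


print(minimum_ignore_nan(float("nan"), 2.0))     # 2.0
print(minimum_propagate_nan(float("nan"), 2.0))  # nan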
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
@@ -1391,7 +1391,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16

     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
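The spill_threshold knob is consulted during runtime autotuning: roughly, a candidate Triton config whose compiled kernel spills more registers than the threshold is treated as non-viable and dropped from benchmarking. Doubling the limit on ROCm keeps the larger-block configs introduced elsewhere in this PR in the tuning pool even when they spill somewhat more. Below is a simplified, hypothetical sketch of that filtering; the names and numbers are made up, and the real logic lives in torch/_inductor/runtime/triton_heuristics.py.

from dataclasses import dataclass


@dataclass
class CompiledCandidate:
    name: str
    n_spills: int
    runtime_ms: float


def bench(candidate: CompiledCandidate, spill_threshold: int) -> float:
    # Report an infinite time for configs that spill past the threshold so the
    # autotuner never selects them.
    if candidate.n_spills > spill_threshold:
        return float("inf")
    return candidate.runtime_ms


candidates = [
    CompiledCandidate("XBLOCK=512", n_spills=0, runtime_ms=0.20),
    CompiledCandidate("XBLOCK=4096", n_spills=24, runtime_ms=0.12),
]

# With the default threshold (16) the second config is rejected; with the ROCm
# value (32) it stays in the pool and wins on measured runtime.
for threshold in (16, 32):
    best = min(candidates, key=lambda c: bench(c, threshold))
    print(threshold, best.name)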
2 changes: 1 addition & 1 deletion torch/_inductor/runtime/hints.py
@@ -13,7 +13,7 @@
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16, # * 16 is multi-kernel only
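As the in-file comment says, these maximums only constrain runtime autotuning (a FixedTritonConfig can exceed them), so raising the X cap to 8192 is what makes the larger XBLOCK pointwise configs added in this PR legal at tuning time. A hypothetical sketch of how such a cap is applied when sizing a block follows; it is not the actual inductor heuristic, and pick_xblock is an invented helper name.

TRITON_MAX_BLOCK = {
    "X": 8192,
    "Y": 1024,
    "Z": 1024,
    "R0_": 4096 * 16,
}


def next_power_of_2(n: int) -> int:
    return 1 << (n - 1).bit_length()


def pick_xblock(xnumel: int, requested: int) -> int:
    # Never generate a block larger than the (rounded-up) problem size or the
    # allowed maximum for the X dimension.
    cap = min(TRITON_MAX_BLOCK["X"], next_power_of_2(xnumel))
    block = min(requested, cap)
    assert block <= TRITON_MAX_BLOCK["X"], "bump TRITON_MAX_BLOCK via a PR"
    return block


print(pick_xblock(xnumel=1_000_000, requested=8192))  # 8192 is now allowed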