Commits (28)
24717c3  Naive foreach autotune support (jataylo, Sep 3, 2025)
ec5765e  Update triton_heuristics.py (jataylo, Sep 4, 2025)
a49093e  Update triton_heuristics.py (jataylo, Sep 9, 2025)
1363a1c  Linting (jataylo, Sep 9, 2025)
f80f0b8  Conditionalize xblock values for HIP. (naromero77amd, Nov 14, 2025)
5ed2f82  Create tiny_config. (naromero77amd, Sep 25, 2025)
c269eb3  pointwise autotuning returnz (#2636) (AmdSampsa, Sep 12, 2025)
d13e652  a few more configs for 2d grids (#2649) (AmdSampsa, Sep 17, 2025)
a13015c  [ROCm][inductor] Additional pointwise tunings (#2642) (naromero77amd, Sep 17, 2025)
3d716eb  pointwise config with MAX_BLOCK. (naromero77amd, Sep 17, 2025)
7edd183  Update triton_config method. (naromero77amd, Sep 17, 2025)
adc6a45  Increase TRITON_MAX_BLOCK X value. (naromero77amd, Sep 18, 2025)
b1f1c36  even more configs (AmdSampsa, Sep 19, 2025)
26f8133  More ROCm conditionalisation (jataylo, Sep 23, 2025)
c1f8e99  Lint. (naromero77amd, Sep 25, 2025)
0a9dcdd  Reduction heursitics improvements for ROCm (naromero77amd, Aug 22, 2025)
768ab84  Add PropagateNan argument to minimum and maximum function. (naromero77amd, Oct 2, 2025)
451c2b4  New WRT configs for autotuning (#2708) (AmdSampsa, Oct 14, 2025)
7850a9c  Align TRITON_MAX_BLOCK code with upstream. (naromero77amd, Nov 15, 2025)
dbdb554  Fix indentation. (naromero77amd, Nov 15, 2025)
83e3be0  Backport InductorConfig class. (naromero77amd, Nov 18, 2025)
d235a15  Patch get_args_with_constexprs (naromero77amd, Nov 19, 2025)
627a571  [NO CP] triton sanity check for 2D POI (#2798) (AmdSampsa, Nov 14, 2025)
c48cc4c  Config specific to 1D pointwise kernels with atomic_add. (naromero77amd, Nov 21, 2025)
e824f40  Lint (naromero77amd, Nov 21, 2025)
ef5f375  No change, just re-ordering of configs to match upstream. (naromero77amd, Nov 21, 2025)
3ad2e71  Remove fbcode helper function. (naromero77amd, Nov 21, 2025)
b1cdd55  Missing loads_and_red variable. (naromero77amd, Nov 21, 2025)
1 change: 1 addition & 0 deletions torch/_inductor/codegen/common.py
@@ -1873,6 +1873,7 @@ def __init__(
         self.compute = IndentedBuffer()
         self.stores = IndentedBuffer()
 
+        self.atomic_add_found = False
         self.num_load = 0
         self.num_reduction = 0
 
20 changes: 16 additions & 4 deletions torch/_inductor/codegen/triton.py
@@ -1002,11 +1002,17 @@ def relu(x):
 
     @staticmethod
     def minimum(a, b):
-        return f"triton_helpers.minimum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.minimum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.minimum({a}, {b})"
 
     @staticmethod
     def maximum(a, b):
-        return f"triton_helpers.maximum({a}, {b})"
+        if torch.version.hip:
+            return f"tl.maximum({a}, {b}, tl.PropagateNan.ALL)"
+        else:
+            return f"triton_helpers.maximum({a}, {b})"
 
     @staticmethod
     def where(a, b, c):
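Note on the hunk above: plain `tl.minimum`/`tl.maximum` do not propagate NaNs by default, so the HIP branch passes `tl.PropagateNan.ALL` to keep the NaN-propagating semantics of `triton_helpers.minimum`/`maximum` used on the non-HIP path. A minimal standalone sketch of the emitted pattern, assuming a Triton build that exposes the `tl.PropagateNan` enum; the kernel and buffer names are illustrative, not Inductor-generated code:

```python
import triton
import triton.language as tl


@triton.jit
def nan_min_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    # PropagateNan.ALL: the result is NaN if either input is NaN, matching
    # the semantics of triton_helpers.minimum on the non-HIP path.
    tl.store(out_ptr + offs, tl.minimum(x, y, tl.PropagateNan.ALL), mask=mask)
```

A launch such as `nan_min_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK=1024)` on GPU tensors would exercise the same call shape the generated kernels use.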
@@ -1202,7 +1208,10 @@ def load_seed(name, offset):
     @staticmethod
     @maybe_upcast_float32()
     def rsqrt(x):
-        return f"libdevice.rsqrt({x})"
+        if torch.version.hip:
+            return f"tl.rsqrt({x})"
+        else:
+            return f"libdevice.rsqrt({x})"
 
     @staticmethod
     @maybe_upcast_float32()
@@ -2285,6 +2294,7 @@ def store(
         elif mode is None:
             line = f"tl.store({var} + ({indexing.index_str}), {value}, {indexing.mask_str})"
         elif mode == "atomic_add":
+            self.atomic_add_found = True
            line = f"tl.atomic_add({var} + ({indexing.index_str}), {value}, {indexing.mask_str}, sem='relaxed')"
         else:
             raise NotImplementedError(f"store mode={mode}")
@@ -3227,8 +3237,9 @@ def codegen_body(self):
                 loop_end = (
                     "rsplit_end" if self.cooperative_reduction else f"{prefix}numel"
                 )
+                num_stages = ", num_stages = 2" if torch.version.hip else ""
                 self.body.writeline(
-                    f"for {prefix}offset in range({loop_start}, {loop_end}, {prefix.upper()}BLOCK):"
+                    f"for {prefix}offset in tl.range({loop_start}, {loop_end}, {prefix.upper()}BLOCK{num_stages}):"
                 )
                 with self.body.indent(offset=level + 1):
                     self.iteration_ranges_codegen_header(tree, self.body)
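The `tl.range(..., num_stages = 2)` form emitted on HIP asks Triton to software-pipeline the reduction loop. A hand-written sketch of the loop shape the generated kernels now take, assuming a Triton that accepts the `num_stages` keyword on `tl.range`; the identifiers are illustrative, not Inductor's generated names:

```python
import triton
import triton.language as tl


@triton.jit
def row_sum_kernel(x_ptr, out_ptr, n_cols, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    acc = tl.zeros([BLOCK], dtype=tl.float32)
    # Equivalent of the generated "for r0_offset in tl.range(..., num_stages = 2)"
    for off in tl.range(0, n_cols, BLOCK, num_stages=2):
        cols = off + tl.arange(0, BLOCK)
        mask = cols < n_cols
        acc += tl.load(x_ptr + row * n_cols + cols, mask=mask, other=0.0).to(tl.float32)
    tl.store(out_ptr + row, tl.sum(acc, axis=0))
```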
@@ -3601,6 +3612,7 @@ def add_constexpr_arg(arg_name):
             "mutated_arg_names": mutated_args,
             "optimize_mem": optimize_mem,
             "no_x_dim": self.no_x_dim,
+            "atomic_add_found": self.atomic_add_found,
             "num_load": self.num_load,
             "num_reduction": self.num_reduction,
             **self.inductor_meta_common(),
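`atomic_add_found` travels through `inductor_meta` so the runtime autotuner can treat kernels containing `tl.atomic_add` differently (see the "Config specific to 1D pointwise kernels with atomic_add" commit). A hypothetical sketch of how a heuristic might branch on the flag; the helper name and config lists are assumptions, not this PR's actual triton_heuristics.py code:

```python
# Hypothetical consumer of the new inductor_meta flag. The helper name and
# the shape of the config lists are illustrative only.
def select_pointwise_configs(inductor_meta, default_configs, atomic_add_configs):
    # 1D pointwise kernels that perform tl.atomic_add tend to favour
    # different block sizes, so swap in a dedicated candidate list.
    if inductor_meta.get("atomic_add_found", False):
        return atomic_add_configs
    return default_configs
```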
2 changes: 1 addition & 1 deletion torch/_inductor/config.py
@@ -1135,7 +1135,7 @@ class triton:
     # So far we see a fixed 8 spilled registers for kernels using sin/cos.
     # Raise the threshold to 16 to be safe.
     # We should revisit this once we understand more of the source of register spills.
-    spill_threshold: int = 16
+    spill_threshold: int = 32 if torch.version.hip else 16
 
     # Generate code containing the newer tl.make_block_ptr() API for loads/store
     use_block_ptr = False
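`spill_threshold` bounds how many spilled registers a candidate config may report before autotuning discards it; the change above doubles the bound on ROCm. A rough sketch of the kind of check that consults it, assuming the compiled kernel's metadata exposes a spill count (`n_spills` here is an assumption, not a specific attribute from this PR):

```python
# Rough sketch of a threshold check during autotuning; n_spills is assumed
# to come from the compiled kernel's metadata.
from torch._inductor import config as inductor_config


def config_is_acceptable(n_spills: int) -> bool:
    # With this PR the bound is 32 on ROCm (torch.version.hip) and 16 elsewhere.
    return n_spills <= inductor_config.triton.spill_threshold
```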
3 changes: 2 additions & 1 deletion torch/_inductor/runtime/hints.py
@@ -7,13 +7,14 @@
 from enum import auto, Enum
 from typing import Optional, Union
 
+import torch
 from torch.utils._triton import has_triton_package
 
 
 # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
 # NOTE: if these fail asserts submit a PR to increase them
 TRITON_MAX_BLOCK = {
-    "X": 4096,
+    "X": 8192 if torch.version.hip else 4096,
     "Y": 1024,
     "Z": 1024,
     "R0_": 4096 * 16,  # * 16 is multi-kernel only
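`TRITON_MAX_BLOCK` caps the block sizes that runtime autotuning may try per dimension, so raising the `X` cap to 8192 on ROCm lets larger 1D pointwise tiles pass the asserts. A small sketch of how a candidate block size could be clamped against the table; the helper itself is illustrative and not part of the PR:

```python
# Illustrative clamp against the per-dimension caps from hints.py.
from torch._inductor.runtime.hints import TRITON_MAX_BLOCK


def clamp_block(prefix: str, block: int) -> int:
    # e.g. clamp_block("x", 16384) -> 8192 on ROCm, 4096 elsewhere
    return min(block, TRITON_MAX_BLOCK[prefix.upper()])
```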