Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
d704bc6e69c1a588c8edd3cbb67505d554ed65f6
5df9c723de8c23508773b07fe16dd34e4c444541
9 changes: 7 additions & 2 deletions torch/_inductor/codegen/triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
from torch.utils._triton import has_triton_package, has_triton_stable_tma_api
from torch.utils._triton import has_triton_package, has_triton_stable_tma_api, get_triton_version

from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
Expand Down Expand Up @@ -1315,7 +1315,12 @@ def tan(x):
@staticmethod
@maybe_upcast_float32()
def tanh(x):
return f"libdevice.fast_tanhf({x})"
if torch.version.hip and get_triton_version() > (3, 2):
# On ROCm, use fast_tanhf depending on Triton version
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
return f"libdevice.fast_tanhf({x})"
else:
return f"libdevice.tanh({x})"

@staticmethod
@maybe_upcast_float32()
Expand Down