diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index a1e9df4725c57..8fcbc3de469f4 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -d704bc6e69c1a588c8edd3cbb67505d554ed65f6 +5df9c723de8c23508773b07fe16dd34e4c444541 diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 17a336cc3cf2e..d55fcf4df449d 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -26,7 +26,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package, has_triton_stable_tma_api +from torch.utils._triton import has_triton_package, has_triton_stable_tma_api, get_triton_version from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1315,7 +1315,12 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - return f"libdevice.fast_tanhf({x})" + if torch.version.hip and get_triton_version() > (3, 2): + # On ROCm, use fast_tanhf depending on Triton version + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ + return f"libdevice.fast_tanhf({x})" + else: + return f"libdevice.tanh({x})" @staticmethod @maybe_upcast_float32()