From be87a96d57a0ab1395c4bb1ca28a753452b673c4 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Wed, 12 Nov 2025 22:07:15 +0000 Subject: [PATCH 1/3] On ROCm, always use fast_tanhf for triton codegen. (cherry picked from commit 7c5277f22a9917902d962f61792e331dfd93cd64) --- torch/_inductor/codegen/triton.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 17a336cc3cf2e..25734fe080bb3 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1315,7 +1315,12 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - return f"libdevice.fast_tanhf({x})" + # On ROCm, always use fast_tanhf + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ + if torch.version.hip: + return f"libdevice.fast_tanhf({x})" + else: + return f"libdevice.tanh({x})" @staticmethod @maybe_upcast_float32() From 0b59f1c2c8cbe8aeb86ce9a5d6aa471f75e76091 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Thu, 13 Nov 2025 01:56:46 +0000 Subject: [PATCH 2/3] Bump up Triton commit to support fast_tanhf. --- .ci/docker/ci_commit_pins/triton.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index a1e9df4725c57..8fcbc3de469f4 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -d704bc6e69c1a588c8edd3cbb67505d554ed65f6 +5df9c723de8c23508773b07fe16dd34e4c444541 From 427f9b0052f4e799d68a25f40715c4cb0cd3bac3 Mon Sep 17 00:00:00 2001 From: "Nichols A. Romero" Date: Fri, 14 Nov 2025 20:27:49 +0000 Subject: [PATCH 3/3] Conditionalize fast_tanhf on triton_version. (cherry picked from commit f416c7119ad1443bf022a37a8f3f21b201aa4bbc) --- torch/_inductor/codegen/triton.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 25734fe080bb3..d55fcf4df449d 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -26,7 +26,7 @@ from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing -from torch.utils._triton import has_triton_package, has_triton_stable_tma_api +from torch.utils._triton import has_triton_package, has_triton_stable_tma_api, get_triton_version from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges @@ -1315,9 +1315,9 @@ def tan(x): @staticmethod @maybe_upcast_float32() def tanh(x): - # On ROCm, always use fast_tanhf - # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ - if torch.version.hip: + if torch.version.hip and get_triton_version() > (3, 2): + # On ROCm, use fast_tanhf depending on Triton version + # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+ return f"libdevice.fast_tanhf({x})" else: return f"libdevice.tanh({x})"