Skip to content

Commit 427f9b0

Browse files
committed
Conditionalize fast_tanhf on triton_version.
(cherry picked from commit f416c71)
1 parent 0b59f1c commit 427f9b0

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from torch._prims_common import is_integer_dtype
2727
from torch.utils._ordered_set import OrderedSet
2828
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
29-
from torch.utils._triton import has_triton_package, has_triton_stable_tma_api
29+
from torch.utils._triton import has_triton_package, has_triton_stable_tma_api, get_triton_version
3030

3131
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
3232
from ...utils._sympy.value_ranges import ValueRanges
@@ -1315,9 +1315,9 @@ def tan(x):
13151315
@staticmethod
13161316
@maybe_upcast_float32()
13171317
def tanh(x):
1318-
# On ROCm, always use fast_tanhf
1319-
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
1320-
if torch.version.hip:
1318+
if torch.version.hip and get_triton_version() > (3, 2):
1319+
# On ROCm, use fast_tanhf depending on Triton version
1320+
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
13211321
return f"libdevice.fast_tanhf({x})"
13221322
else:
13231323
return f"libdevice.tanh({x})"

0 commit comments

Comments
 (0)