File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 2626from torch ._prims_common import is_integer_dtype
2727from torch .utils ._ordered_set import OrderedSet
2828from torch .utils ._sympy .functions import CeilDiv , FloorDiv , ModularIndexing
29- from torch .utils ._triton import has_triton_package
29+ from torch .utils ._triton import has_triton_package , get_triton_version
3030
3131from ...utils ._sympy .symbol import free_symbol_is_type , prefix_str , symbol_is_type , SymT
3232from ...utils ._sympy .value_ranges import ValueRanges
@@ -1232,9 +1232,9 @@ def tan(x):
12321232 @staticmethod
12331233 @maybe_upcast_float32 ()
12341234 def tanh (x ):
1235- # On ROCm, always use fast_tanhf
1236- # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
1237- if torch . version . hip :
1235+ if torch . version . hip and get_triton_version () > ( 3 , 2 ):
1236+ # On ROCm, use fast_tanhf depending on Triton version
1237+ # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
12381238 return f"libdevice.fast_tanhf({ x } )"
12391239 else :
12401240 return f"libdevice.tanh({ x } )"
You can’t perform that action at this time.
0 commit comments