File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change 2626from torch ._prims_common import is_integer_dtype
2727from torch .utils ._ordered_set import OrderedSet
2828from torch .utils ._sympy .functions import CeilDiv , FloorDiv , ModularIndexing
29- from torch .utils ._triton import has_triton_package , has_triton_stable_tma_api
29+ from torch .utils ._triton import has_triton_package , has_triton_stable_tma_api , get_triton_version
3030
3131from ...utils ._sympy .symbol import free_symbol_is_type , prefix_str , symbol_is_type , SymT
3232from ...utils ._sympy .value_ranges import ValueRanges
@@ -1315,9 +1315,9 @@ def tan(x):
13151315 @staticmethod
13161316 @maybe_upcast_float32 ()
13171317 def tanh (x ):
1318- # On ROCm, always use fast_tanhf
1319- # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
1320- if torch . version . hip :
1318+ if torch . version . hip and get_triton_version () > ( 3 , 2 ):
1319+ # On ROCm, use fast_tanhf depending on Triton version
1320+ # Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
13211321 return f"libdevice.fast_tanhf({ x } )"
13221322 else :
13231323 return f"libdevice.tanh({ x } )"
You can’t perform that action at this time.
0 commit comments