Skip to content

Commit 6110a92

Browse files
authored
Lock fast_tanh path behind USE_FAST_MATH control to resolve NaN issues (#2718)
NaN issues discovered after fast_tanh change https://ontrack-internal.amd.com/browse/SWDEV-560271 making this non-default path until further debugging
1 parent b8b81a9 commit 6110a92

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

torch/_inductor/codegen/triton.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from torch._prims_common import is_integer_dtype
2626
from torch.utils._ordered_set import OrderedSet
2727
from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing
28-
from torch.utils._triton import has_triton_package
28+
from torch.utils._triton import has_triton_package, get_triton_version
2929

3030
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
3131
from ...utils._sympy.value_ranges import ValueRanges
@@ -1217,7 +1217,13 @@ def tan(x):
12171217
@staticmethod
12181218
@maybe_upcast_float32()
12191219
def tanh(x):
1220-
return f"libdevice.fast_tanhf({x})"
1220+
if config.use_fast_math and torch.version.hip:
1221+
if get_triton_version() > (3, 4):
1222+
return f"libdevice.fast_tanhf({x})"
1223+
else:
1224+
return f"libdevice.tanh({x})"
1225+
else:
1226+
return f"libdevice.tanh({x})"
12211227

12221228
@staticmethod
12231229
@maybe_upcast_float32()

torch/utils/_triton.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,17 @@ def has_triton_tma_device():
6161
return False
6262

6363

64+
@functools.cache
65+
def get_triton_version(fallback: tuple[int, int] = (0, 0)) -> tuple[int, int]:
66+
try:
67+
import triton # noqa: F401
68+
69+
major, minor = tuple(int(v) for v in triton.__version__.split(".")[:2])
70+
return (major, minor)
71+
except ImportError:
72+
return fallback
73+
74+
6475
@functools.lru_cache(None)
6576
def has_triton() -> bool:
6677
if not has_triton_package():

0 commit comments

Comments
 (0)