Skip to content

Commit 9dc9120

Browse files
[release/2.7][ROCm][inductor] Improved fast_tanh code generation (#2802)
In the ROCm fork of PyTorch 2.7, Inductor currently has codegen support for fast_tanhf. However, it is currently guarded by `TORCHINDUCTOR_USE_FAST_MATH` environment variable due to some NaN issues in the original Triton implementation of fast_tanhf. Upstream Triton has an improved fast_tanhf where the NaN issues are now fixed. This upstream commit has been backported to ROCm fork of Triton (see code comments). Thus, I have removed the conditionalization on Triton versions as well. A bump in the Triton commit is also needed. Other notes: - In support of [SWDEV-560271](https://ontrack-internal.amd.com/browse/SWDEV-560271) - Triton 3.3 backport of upstream Triton commit ROCm/triton#902 - Similar to #2803, #2804 - Related to pytorch#162052
1 parent e311287 commit 9dc9120

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
9c7bc0a3d41407bff948b40cd0e9c793147e49bc
1+
80ed7f41e4b5d6e71651847e4725f4e7c2999a08

torch/_inductor/codegen/triton.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,11 +1217,10 @@ def tan(x):
12171217
@staticmethod
12181218
@maybe_upcast_float32()
12191219
def tanh(x):
1220-
if config.use_fast_math and torch.version.hip:
1221-
if get_triton_version() > (3, 4):
1222-
return f"libdevice.fast_tanhf({x})"
1223-
else:
1224-
return f"libdevice.tanh({x})"
1220+
if torch.version.hip and get_triton_version() > (3, 2):
1221+
# On ROCm, use fast_tanhf depending on Triton version
1222+
# Requires ROCm fork of Triton 3.3, 3.4, 3.5 or upstream Triton 3.6+
1223+
return f"libdevice.fast_tanhf({x})"
12251224
else:
12261225
return f"libdevice.tanh({x})"
12271226

0 commit comments

Comments
 (0)