Skip to content

Commit 3f5eb50

Browse files
[AMD] reimplement fast_tanhf() to avoid overflow (#8551)
### The Problem with the Original Formula The original formula is: ``` tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) ``` - Issue with large positive x: - When x = 20: e^(40) ≈ 2.4 × 10^17 → manageable - When x = 50: e^(100) ≈ 2.7 × 10^43 → overflow to infinity - Result: (∞ - 1)/(∞ + 1) = NaN x - For negative x: The formula actually works fine because e^(2x) → 0, giving (-1)/(1) = -1 ### The Numerically Stable Solution - For Positive x: Reformulation ``` tanh(x) = (e^(2x) - 1) / (e^(2x) + 1) = (e^(2x) + 1 - 2) / (e^(2x) + 1) = 1 - 2/(e^(2x) + 1) ``` - For Negative x: Using Symmetry ``` tanh(-x) = (e^(-2x) - 1) / (e^(-2x) + 1) = (2/(e^(-2x) + 1) - 1) = -1 × (1 - 2/(e^(2|x|) + 1)) ``` ### Unified formulation: ``` tanh(x) = sign(x) × (1 - 2/(e^(2|x|) + 1)) ```
1 parent 430f8b2 commit 3f5eb50

File tree

1 file changed

+38
-18
lines changed

1 file changed

+38
-18
lines changed

third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp

Lines changed: 38 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -106,29 +106,49 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
106106
assert(operands[0].getType().getIntOrFloatBitWidth() == 32);
107107
LLVM::FastmathFlagsAttr defaultFlags{};
108108

109-
// Calculate 2*x
110-
auto twoX = LLVM::FMulOp::create(
111-
rewriter, loc, rewriter.getF32Type(), operands[0],
109+
// Numerically stable tanh implementation:
110+
// For positive x: tanh(x) = 1 - 2/(e^(2x) + 1)
111+
// For negative x: tanh(x) = -tanh(-x) = -(1 - 2/(e^(-2x) + 1))
112+
// = 2/(e^(-2x) + 1) - 1
113+
// This avoids overflow when e^(2x) becomes infinity for large x
114+
115+
// Get absolute value of x
116+
auto absX = LLVM::FAbsOp::create(rewriter, loc, rewriter.getF32Type(),
117+
operands[0]);
118+
119+
// Calculate 2*|x|
120+
auto twoAbsX = LLVM::FMulOp::create(
121+
rewriter, loc, rewriter.getF32Type(), absX,
112122
LLVM::createConstantF32(loc, rewriter, 2.0), defaultFlags);
113123

114-
// Calculate fast_expf(2*x) using the utility function
115-
auto exp2X = createFastExpf(rewriter, loc, twoX->getResult(0),
116-
rewriter.getF32Type(), ftz);
124+
// Calculate e^(2*|x|)
125+
auto exp2AbsX = createFastExpf(rewriter, loc, twoAbsX->getResult(0),
126+
rewriter.getF32Type(), ftz);
117127

118-
// Calculate exp2X - 1
119-
auto exp2XMinus1 = LLVM::FSubOp::create(
120-
rewriter, loc, rewriter.getF32Type(), exp2X->getResult(0),
128+
// Calculate e^(2*|x|) + 1
129+
auto exp2AbsXPlus1 = LLVM::FAddOp::create(
130+
rewriter, loc, rewriter.getF32Type(), exp2AbsX->getResult(0),
121131
LLVM::createConstantF32(loc, rewriter, 1.0), defaultFlags);
122132

123-
// Calculate exp2X + 1
124-
auto exp2XPlus1 = LLVM::FAddOp::create(
125-
rewriter, loc, rewriter.getF32Type(), exp2X->getResult(0),
126-
LLVM::createConstantF32(loc, rewriter, 1.0), defaultFlags);
127-
128-
// Calculate tanh(X) = (exp2X - 1) / (exp2X + 1)
129-
replacementOp = LLVM::FDivOp::create(
130-
rewriter, loc, returnType, exp2XMinus1->getResult(0),
131-
exp2XPlus1->getResult(0), defaultFlags);
133+
// Calculate 2 / (e^(2*|x|) + 1)
134+
auto two = LLVM::createConstantF32(loc, rewriter, 2.0);
135+
auto ratio =
136+
LLVM::FDivOp::create(rewriter, loc, rewriter.getF32Type(), two,
137+
exp2AbsXPlus1->getResult(0), defaultFlags);
138+
139+
// Calculate 1 - 2/(e^(2*|x|) + 1)
140+
auto one = LLVM::createConstantF32(loc, rewriter, 1.0);
141+
auto posResult =
142+
LLVM::FSubOp::create(rewriter, loc, rewriter.getF32Type(), one,
143+
ratio->getResult(0), defaultFlags);
144+
145+
// Apply the sign of the original input using copysign
146+
// tanh(x) = sign(x) * (1 - 2/(e^(2*|x|) + 1))
147+
const char *intrinsic = "llvm.copysign.f32";
148+
auto args =
149+
llvm::SmallVector<Value>{posResult->getResult(0), operands[0]};
150+
replacementOp = LLVM::createLLVMIntrinsicCallOp(rewriter, loc, intrinsic,
151+
returnType, args);
132152
}
133153

134154
if (replacementOp) {

0 commit comments

Comments
 (0)