@@ -106,29 +106,49 @@ class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
106106 assert (operands[0 ].getType ().getIntOrFloatBitWidth () == 32 );
107107 LLVM::FastmathFlagsAttr defaultFlags{};
108108
109- // Calculate 2*x
110- auto twoX = LLVM::FMulOp::create (
111- rewriter, loc, rewriter.getF32Type (), operands[0 ],
109+ // Numerically stable tanh implementation:
110+ // For positive x: tanh(x) = 1 - 2/(e^(2x) + 1)
111+ // For negative x: tanh(x) = -tanh(-x) = -(1 - 2/(e^(-2x) + 1))
112+ // = 2/(e^(-2x) + 1) - 1
113+ // This avoids overflow when e^(2x) becomes infinity for large x
114+
115+ // Get absolute value of x
116+ auto absX = LLVM::FAbsOp::create (rewriter, loc, rewriter.getF32Type (),
117+ operands[0 ]);
118+
119+ // Calculate 2*|x|
120+ auto twoAbsX = LLVM::FMulOp::create (
121+ rewriter, loc, rewriter.getF32Type (), absX,
112122 LLVM::createConstantF32 (loc, rewriter, 2.0 ), defaultFlags);
113123
114- // Calculate fast_expf (2*x) using the utility function
115- auto exp2X = createFastExpf (rewriter, loc, twoX ->getResult (0 ),
116- rewriter.getF32Type (), ftz);
124+ // Calculate e^ (2*|x|)
125+ auto exp2AbsX = createFastExpf (rewriter, loc, twoAbsX ->getResult (0 ),
126+ rewriter.getF32Type (), ftz);
117127
118- // Calculate exp2X - 1
119- auto exp2XMinus1 = LLVM::FSubOp ::create (
120- rewriter, loc, rewriter.getF32Type (), exp2X ->getResult (0 ),
128+ // Calculate e^(2*|x|) + 1
129+ auto exp2AbsXPlus1 = LLVM::FAddOp ::create (
130+ rewriter, loc, rewriter.getF32Type (), exp2AbsX ->getResult (0 ),
121131 LLVM::createConstantF32 (loc, rewriter, 1.0 ), defaultFlags);
122132
123- // Calculate exp2X + 1
124- auto exp2XPlus1 = LLVM::FAddOp::create (
125- rewriter, loc, rewriter.getF32Type (), exp2X->getResult (0 ),
126- LLVM::createConstantF32 (loc, rewriter, 1.0 ), defaultFlags);
127-
128- // Calculate tanh(X) = (exp2X - 1) / (exp2X + 1)
129- replacementOp = LLVM::FDivOp::create (
130- rewriter, loc, returnType, exp2XMinus1->getResult (0 ),
131- exp2XPlus1->getResult (0 ), defaultFlags);
133+ // Calculate 2 / (e^(2*|x|) + 1)
134+ auto two = LLVM::createConstantF32 (loc, rewriter, 2.0 );
135+ auto ratio =
136+ LLVM::FDivOp::create (rewriter, loc, rewriter.getF32Type (), two,
137+ exp2AbsXPlus1->getResult (0 ), defaultFlags);
138+
139+ // Calculate 1 - 2/(e^(2*|x|) + 1)
140+ auto one = LLVM::createConstantF32 (loc, rewriter, 1.0 );
141+ auto posResult =
142+ LLVM::FSubOp::create (rewriter, loc, rewriter.getF32Type (), one,
143+ ratio->getResult (0 ), defaultFlags);
144+
145+ // Apply the sign of the original input using copysign
146+ // tanh(x) = sign(x) * (1 - 2/(e^(2*|x|) + 1))
147+ const char *intrinsic = " llvm.copysign.f32" ;
148+ auto args =
149+ llvm::SmallVector<Value>{posResult->getResult (0 ), operands[0 ]};
150+ replacementOp = LLVM::createLLVMIntrinsicCallOp (rewriter, loc, intrinsic,
151+ returnType, args);
132152 }
133153
134154 if (replacementOp) {
0 commit comments