@@ -313,25 +313,53 @@ struct DivOpConversion : public OpConversionPattern<complex::DivOp> {
313313struct ExpOpConversion : public OpConversionPattern <complex ::ExpOp> {
314314 using OpConversionPattern<complex ::ExpOp>::OpConversionPattern;
315315
316+ // exp(x+I*y) = exp(x)*(cos(y)+I*sin(y))
317+ // Handle special cases as StableHLO implementation does:
318+ // 1. When b == 0, set imag(exp(z)) = 0
319+ // 2. When exp(x) == inf, use exp(x/2)*(cos(y)+I*sin(y))*exp(x/2)
316320 LogicalResult
317321 matchAndRewrite (complex ::ExpOp op, OpAdaptor adaptor,
318322 ConversionPatternRewriter &rewriter) const override {
319323 auto loc = op.getLoc ();
320324 auto type = cast<ComplexType>(adaptor.getComplex ().getType ());
321- auto elementType = cast<FloatType>(type.getElementType ());
322- arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr ();
323-
324- Value real =
325- complex::ReOp::create (rewriter, loc, elementType, adaptor.getComplex ());
326- Value imag =
327- complex::ImOp::create (rewriter, loc, elementType, adaptor.getComplex ());
328- Value expReal = math::ExpOp::create (rewriter, loc, real, fmf.getValue ());
329- Value cosImag = math::CosOp::create (rewriter, loc, imag, fmf.getValue ());
325+ auto ET = cast<FloatType>(type.getElementType ());
326+ arith::FastMathFlags fmf = op.getFastMathFlagsAttr ().getValue ();
327+ const auto &floatSemantics = ET.getFloatSemantics ();
328+ ImplicitLocOpBuilder b (loc, rewriter);
329+
330+ Value x = complex::ReOp::create (b, ET, adaptor.getComplex ());
331+ Value y = complex::ImOp::create (b, ET, adaptor.getComplex ());
332+ Value zero = arith::ConstantOp::create (b, ET, b.getZeroAttr (ET));
333+ Value half = arith::ConstantOp::create (b, ET, b.getFloatAttr (ET, 0.5 ));
334+ Value inf = arith::ConstantOp::create (
335+ b, ET, b.getFloatAttr (ET, APFloat::getInf (floatSemantics)));
336+
337+ Value exp = math::ExpOp::create (b, x, fmf);
338+ Value xHalf = arith::MulFOp::create (b, x, half, fmf);
339+ Value expHalf = math::ExpOp::create (b, xHalf, fmf);
340+ Value cos = math::CosOp::create (b, y, fmf);
341+ Value sin = math::SinOp::create (b, y, fmf);
342+
343+ Value expIsInf =
344+ arith::CmpFOp::create (b, arith::CmpFPredicate::OEQ, exp, inf, fmf);
345+ Value yIsZero =
346+ arith::CmpFOp::create (b, arith::CmpFPredicate::OEQ, y, zero);
347+
348+ // Real path: select between exp(x)*cos(y) and exp(x/2)*cos(y)*exp(x/2)
349+ Value realNormal = arith::MulFOp::create (b, exp, cos, fmf);
350+ Value expHalfCos = arith::MulFOp::create (b, expHalf, cos, fmf);
351+ Value realOverflow = arith::MulFOp::create (b, expHalfCos, expHalf, fmf);
330352 Value resultReal =
331- arith::MulFOp::create (rewriter, loc, expReal, cosImag, fmf.getValue ());
332- Value sinImag = math::SinOp::create (rewriter, loc, imag, fmf.getValue ());
333- Value resultImag =
334- arith::MulFOp::create (rewriter, loc, expReal, sinImag, fmf.getValue ());
353+ arith::SelectOp::create (b, expIsInf, realOverflow, realNormal);
354+
355+ // Imaginary part: if y == 0 return 0 else select between exp(x)*sin(y) and
356+ // exp(x/2)*sin(y)*exp(x/2)
357+ Value imagNormal = arith::MulFOp::create (b, exp, sin, fmf);
358+ Value expHalfSin = arith::MulFOp::create (b, expHalf, sin, fmf);
359+ Value imagOverflow = arith::MulFOp::create (b, expHalfSin, expHalf, fmf);
360+ Value imagNonZero =
361+ arith::SelectOp::create (b, expIsInf, imagOverflow, imagNormal);
362+ Value resultImag = arith::SelectOp::create (b, yIsZero, zero, imagNonZero);
335363
336364 rewriter.replaceOpWithNewOp <complex ::CreateOp>(op, type, resultReal,
337365 resultImag);
0 commit comments