diff --git a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp index 6656be830989a..9282518191274 100644 --- a/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp +++ b/mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp @@ -520,29 +520,94 @@ struct ExpOpConversion : public OpConversionPattern { } }; +Value evaluatePolynomial(ImplicitLocOpBuilder &b, Value arg, + ArrayRef coefficients, + arith::FastMathFlagsAttr fmf) { + auto argType = mlir::cast(arg.getType()); + Value poly = + b.create(b.getFloatAttr(argType, coefficients[0])); + for (int i = 1; i < coefficients.size(); ++i) { + poly = b.create( + poly, arg, + b.create(b.getFloatAttr(argType, coefficients[i])), + fmf); + } + return poly; +} + struct Expm1OpConversion : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; + // e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i + // [handle inaccuracies when a and/or b are small] + // = ((e^a - 1) * cos(b) + cos(b) - 1) + e^a*sin(b)i + // = (expm1(a) * cos(b) + cosm1(b)) + e^a*sin(b)i LogicalResult matchAndRewrite(complex::Expm1Op op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto type = cast(adaptor.getComplex().getType()); - auto elementType = cast(type.getElementType()); + auto type = op.getType(); + auto elemType = mlir::cast(type.getElementType()); + arith::FastMathFlagsAttr fmf = op.getFastMathFlagsAttr(); + ImplicitLocOpBuilder b(op.getLoc(), rewriter); + Value real = b.create(adaptor.getComplex()); + Value imag = b.create(adaptor.getComplex()); - mlir::ImplicitLocOpBuilder b(op.getLoc(), rewriter); - Value exp = b.create(adaptor.getComplex(), fmf.getValue()); + Value zero = b.create(b.getFloatAttr(elemType, 0.0)); + Value one = b.create(b.getFloatAttr(elemType, 1.0)); - Value real = b.create(elementType, exp); - Value one = b.create(elementType, - b.getFloatAttr(elementType, 1)); - Value realMinusOne = b.create(real, one, fmf.getValue()); - Value imag = b.create(elementType, exp); + Value expm1Real = b.create(real, fmf); + Value expReal = b.create(expm1Real, one, fmf); + + Value sinImag = b.create(imag, fmf); + Value cosm1Imag = emitCosm1(imag, fmf, b); + Value cosImag = b.create(cosm1Imag, one, fmf); - rewriter.replaceOpWithNewOp(op, type, realMinusOne, - imag); + Value realResult = b.create( + b.create(expm1Real, cosImag, fmf), cosm1Imag, fmf); + + Value imagIsZero = b.create(arith::CmpFPredicate::OEQ, imag, + zero, fmf.getValue()); + Value imagResult = b.create( + imagIsZero, zero, b.create(expReal, sinImag, fmf)); + + rewriter.replaceOpWithNewOp(op, type, realResult, + imagResult); return success(); } + +private: + Value emitCosm1(Value arg, arith::FastMathFlagsAttr fmf, + ImplicitLocOpBuilder &b) const { + auto argType = mlir::cast(arg.getType()); + auto negHalf = b.create(b.getFloatAttr(argType, -0.5)); + auto negOne = b.create(b.getFloatAttr(argType, -1.0)); + + // Algorithm copied from cephes cosm1. + SmallVector kCoeffs{ + 4.7377507964246204691685E-14, -1.1470284843425359765671E-11, + 2.0876754287081521758361E-9, -2.7557319214999787979814E-7, + 2.4801587301570552304991E-5, -1.3888888888888872993737E-3, + 4.1666666666666666609054E-2, + }; + Value cos = b.create(arg, fmf); + Value forLargeArg = b.create(cos, negOne, fmf); + + Value argPow2 = b.create(arg, arg, fmf); + Value argPow4 = b.create(argPow2, argPow2, fmf); + Value poly = evaluatePolynomial(b, argPow2, kCoeffs, fmf); + + auto forSmallArg = + b.create(b.create(argPow4, poly, fmf), + b.create(negHalf, argPow2, fmf)); + + // (pi/4)^2 is approximately 0.61685 + Value piOver4Pow2 = + b.create(b.getFloatAttr(argType, 0.61685)); + Value cond = b.create(arith::CmpFPredicate::OGE, argPow2, + piOver4Pow2, fmf.getValue()); + return b.create(cond, forLargeArg, forSmallArg); + } }; struct LogOpConversion : public OpConversionPattern { diff --git a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir index d7767bda08435..3d73292e6b886 100644 --- a/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir +++ b/mlir/test/Conversion/ComplexToStandard/convert-to-standard.mlir @@ -221,26 +221,52 @@ func.func @complex_exp(%arg: complex) -> complex { // ----- -// CHECK-LABEL: func.func @complex_expm1( -// CHECK-SAME: %[[ARG:.*]]: complex) -> complex { +// CHECK-LABEL: func.func @complex_expm1( +// CHECK-SAME: %[[ARG:.*]]: complex) -> complex { func.func @complex_expm1(%arg: complex) -> complex { - %expm1 = complex.expm1 %arg: complex + %expm1 = complex.expm1 %arg fastmath : complex return %expm1 : complex } -// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex -// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] : f32 -// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] : f32 -// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] : f32 -// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] : f32 -// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] : f32 -// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex -// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] : f32 -// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex -// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex -// CHECK: return %[[RES]] : complex +// CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex +// CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex +// CHECK-DAG: %[[C0_F32:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[C1_F32:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK: %[[EXPM1:.*]] = math.expm1 %[[REAL]] fastmath : f32 +// CHECK: %[[VAL_6:.*]] = arith.addf %[[EXPM1]], %[[C1_F32]] fastmath : f32 +// CHECK: %[[VAL_7:.*]] = math.sin %[[IMAG]] fastmath : f32 +// CHECK: %[[VAL_8:.*]] = arith.constant -5.000000e-01 : f32 +// CHECK: %[[VAL_9:.*]] = arith.constant -1.000000e+00 : f32 +// CHECK: %[[VAL_10:.*]] = math.cos %[[IMAG]] fastmath : f32 +// CHECK: %[[VAL_11:.*]] = arith.addf %[[VAL_10]], %[[VAL_9]] fastmath : f32 +// CHECK: %[[VAL_12:.*]] = arith.mulf %[[IMAG]], %[[IMAG]] fastmath : f32 +// CHECK: %[[VAL_13:.*]] = arith.mulf %[[VAL_12]], %[[VAL_12]] fastmath : f32 +// CHECK-DAG: %[[COEF0:.*]] = arith.constant 4.73775072E-14 : f32 +// CHECK-DAG: %[[COEF1:.*]] = arith.constant -1.14702848E-11 : f32 +// CHECK: %[[FMA0:.*]] = math.fma %[[COEF0]], %[[VAL_12]], %[[COEF1]] fastmath : f32 +// CHECK: %[[COEF2:.*]] = arith.constant 2.08767537E-9 : f32 +// CHECK: %[[FMA1:.*]] = math.fma %[[FMA0]], %[[VAL_12]], %[[COEF2]] fastmath : f32 +// CHECK: %[[COEF3:.*]] = arith.constant -2.755732E-7 : f32 +// CHECK: %[[FMA2:.*]] = math.fma %[[FMA1]], %[[VAL_12]], %[[COEF3]] fastmath : f32 +// CHECK: %[[COEF4:.*]] = arith.constant 2.48015876E-5 : f32 +// CHECK: %[[FMA3:.*]] = math.fma %[[FMA2]], %[[VAL_12]], %[[COEF4]] fastmath : f32 +// CHECK: %[[COEF5:.*]] = arith.constant -0.00138888892 : f32 +// CHECK: %[[FMA4:.*]] = math.fma %[[FMA3]], %[[VAL_12]], %[[COEF5]] fastmath : f32 +// CHECK: %[[COEF6:.*]] = arith.constant 0.0416666679 : f32 +// CHECK: %[[FMA5:.*]] = math.fma %[[FMA4]], %[[VAL_12]], %[[COEF6]] fastmath : f32 +// CHECK-DAG: %[[VAL_27:.*]] = arith.mulf %[[VAL_13]], %[[FMA5]] fastmath : f32 +// CHECK-DAG: %[[VAL_28:.*]] = arith.mulf %[[VAL_8]], %[[VAL_12]] fastmath : f32 +// CHECK: %[[VAL_29:.*]] = arith.addf %[[VAL_27]], %[[VAL_28]] : f32 +// CHECK: %[[VAL_30:.*]] = arith.constant 6.168500e-01 : f32 +// CHECK: %[[VAL_31:.*]] = arith.cmpf oge, %[[VAL_12]], %[[VAL_30]] fastmath : f32 +// CHECK: %[[VAL_32:.*]] = arith.select %[[VAL_31]], %[[VAL_11]], %[[VAL_29]] : f32 +// CHECK: %[[VAL_33:.*]] = arith.addf %[[VAL_32]], %[[C1_F32]] fastmath : f32 +// CHECK: %[[VAL_34:.*]] = arith.mulf %[[EXPM1]], %[[VAL_33]] fastmath : f32 +// CHECK: %[[VAL_35:.*]] = arith.addf %[[VAL_34]], %[[VAL_32]] fastmath : f32 +// CHECK: %[[VAL_36:.*]] = arith.cmpf oeq, %[[IMAG]], %[[C0_F32]] fastmath : f32 +// CHECK: %[[VAL_37:.*]] = arith.mulf %[[VAL_6]], %[[VAL_7]] fastmath : f32 +// CHECK: %[[VAL_38:.*]] = arith.select %[[VAL_36]], %[[C0_F32]], %[[VAL_37]] : f32 +// CHECK: %[[RESULT:.*]] = complex.create %[[VAL_35]], %[[VAL_38]] : complex +// CHECK: return %[[RESULT]] : complex // ----- @@ -882,29 +908,6 @@ func.func @complex_exp_with_fmf(%arg: complex) -> complex { // ----- -// CHECK-LABEL: func.func @complex_expm1_with_fmf( -// CHECK-SAME: %[[ARG:.*]]: complex) -> complex { -func.func @complex_expm1_with_fmf(%arg: complex) -> complex { - %expm1 = complex.expm1 %arg fastmath : complex - return %expm1 : complex -} -// CHECK: %[[REAL_I:.*]] = complex.re %[[ARG]] : complex -// CHECK: %[[IMAG_I:.*]] = complex.im %[[ARG]] : complex -// CHECK: %[[EXP:.*]] = math.exp %[[REAL_I]] fastmath : f32 -// CHECK: %[[COS:.*]] = math.cos %[[IMAG_I]] fastmath : f32 -// CHECK: %[[RES_REAL:.*]] = arith.mulf %[[EXP]], %[[COS]] fastmath : f32 -// CHECK: %[[SIN:.*]] = math.sin %[[IMAG_I]] fastmath : f32 -// CHECK: %[[RES_IMAG:.*]] = arith.mulf %[[EXP]], %[[SIN]] fastmath : f32 -// CHECK: %[[RES_EXP:.*]] = complex.create %[[RES_REAL]], %[[RES_IMAG]] : complex -// CHECK: %[[REAL:.*]] = complex.re %[[RES_EXP]] : complex -// CHECK: %[[ONE:.*]] = arith.constant 1.000000e+00 : f32 -// CHECK: %[[REAL_M1:.*]] = arith.subf %[[REAL]], %[[ONE]] fastmath : f32 -// CHECK: %[[IMAG:.*]] = complex.im %[[RES_EXP]] : complex -// CHECK: %[[RES:.*]] = complex.create %[[REAL_M1]], %[[IMAG]] : complex -// CHECK: return %[[RES]] : complex - -// ----- - // CHECK-LABEL: func @complex_log_with_fmf // CHECK-SAME: %[[ARG:.*]]: complex func.func @complex_log_with_fmf(%arg: complex) -> complex { @@ -2020,4 +2023,4 @@ func.func @complex_angle_with_fmf(%arg: complex) -> f32 { // CHECK: %[[REAL:.*]] = complex.re %[[ARG]] : complex // CHECK: %[[IMAG:.*]] = complex.im %[[ARG]] : complex // CHECK: %[[RESULT:.*]] = math.atan2 %[[IMAG]], %[[REAL]] fastmath : f32 -// CHECK: return %[[RESULT]] : f32 \ No newline at end of file +// CHECK: return %[[RESULT]] : f32