diff --git a/mlir/include/mlir/Dialect/Math/IR/MathOps.td b/mlir/include/mlir/Dialect/Math/IR/MathOps.td index 8a277320e2f91..16ce4e2366c76 100644 --- a/mlir/include/mlir/Dialect/Math/IR/MathOps.td +++ b/mlir/include/mlir/Dialect/Math/IR/MathOps.td @@ -560,6 +560,31 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> { let hasFolder = 1; } +//===----------------------------------------------------------------------===// +// ErfcOp +//===----------------------------------------------------------------------===// + +def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> { + let summary = "complementary error function of the specified value"; + let description = [{ + + The `erfc` operation computes the complementary error function, defined as + 1-erf(x). This function is part of libm and is needed for accuracy, since + simply calculating 1-erf(x) when erf(x) is close to 1 will give inaccurate results. + It takes one operand of floating point type (i.e., scalar, + tensor or vector) and returns one result of the same type. It has no + standard attributes. + + Example: + + ```mlir + // Scalar complementary error function value. 
+ %a = math.erfc %b : f64 + ``` + }]; + let hasFolder = 1; +} + //===----------------------------------------------------------------------===// // ExpOp diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h index b4ebc2f0f8fcd..ecfdb71817dff 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Approximation.h @@ -23,6 +23,14 @@ struct ErfPolynomialApproximation : public OpRewritePattern { PatternRewriter &rewriter) const final; }; +struct ErfcPolynomialApproximation : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(math::ErfcOp op, + PatternRewriter &rewriter) const final; +}; + } // namespace math } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h index ea7a556297a76..9adc1c6940a15 100644 --- a/mlir/include/mlir/Dialect/Math/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Math/Transforms/Passes.h @@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions { void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns); void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns); +void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns); // Adds patterns to convert to f32 around math functions for which `predicate` // returns true. 
diff --git a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp index c21ee9652b499..c4792884eb34a 100644 --- a/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp +++ b/mlir/lib/Conversion/MathToLibm/MathToLibm.cpp @@ -181,6 +181,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns, populatePatternsForOp(patterns, benefit, ctx, "cosf", "cos"); populatePatternsForOp(patterns, benefit, ctx, "coshf", "cosh"); populatePatternsForOp(patterns, benefit, ctx, "erff", "erf"); + populatePatternsForOp(patterns, benefit, ctx, "erfcf", "erfc"); populatePatternsForOp(patterns, benefit, ctx, "expf", "exp"); populatePatternsForOp(patterns, benefit, ctx, "exp2f", "exp2"); populatePatternsForOp(patterns, benefit, ctx, "expm1f", diff --git a/mlir/lib/Dialect/Math/IR/MathOps.cpp b/mlir/lib/Dialect/Math/IR/MathOps.cpp index 42e357c012739..9c4d88e2191ce 100644 --- a/mlir/lib/Dialect/Math/IR/MathOps.cpp +++ b/mlir/lib/Dialect/Math/IR/MathOps.cpp @@ -332,6 +332,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) { }); } +//===----------------------------------------------------------------------===// +// ErfcOp folder +//===----------------------------------------------------------------------===// + +OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) { + return constFoldUnaryOpConditional( + adaptor.getOperands(), [](const APFloat &a) -> std::optional { + switch (APFloat::SemanticsToEnum(a.getSemantics())) { + case APFloat::Semantics::S_IEEEdouble: + return APFloat(erfc(a.convertToDouble())); + case APFloat::Semantics::S_IEEEsingle: + return APFloat(erfcf(a.convertToFloat())); + default: + return {}; + } + }); +} + //===----------------------------------------------------------------------===// // IPowIOp folder //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp 
b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp index 777427de9465c..167eebd786dba 100644 --- a/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp +++ b/mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp @@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, // Helper functions to create constants. //----------------------------------------------------------------------------// +static Value boolCst(ImplicitLocOpBuilder &builder, bool value) { + return builder.create(builder.getBoolAttr(value)); +} + static Value floatCst(ImplicitLocOpBuilder &builder, float value, Type elementType) { assert((elementType.isF16() || elementType.isF32()) && @@ -1118,6 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op, return success(); } +// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a degree-9 +// polynomial. This approximation is based on the following stackoverflow post: +// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf +// The stackoverflow post is in turn based on: +// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of +// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, +// No. 153, January 1981, pp. 249-253. 
+// +// Maximum error: 2.65 ulps +LogicalResult +ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op, + PatternRewriter &rewriter) const { + Value x = op.getOperand(); + Type et = getElementTypeOrSelf(x); + + if (!et.isF32()) + return rewriter.notifyMatchFailure(op, "only f32 type is supported."); + std::optional shape = vectorShape(x); + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + auto bcast = [&](Value value) -> Value { + return broadcast(builder, value, shape); + }; + + Value trueValue = bcast(boolCst(builder, true)); + Value zero = bcast(floatCst(builder, 0.0f, et)); + Value one = bcast(floatCst(builder, 1.0f, et)); + Value onehalf = bcast(floatCst(builder, 0.5f, et)); + Value neg4 = bcast(floatCst(builder, -4.0f, et)); + Value neg2 = bcast(floatCst(builder, -2.0f, et)); + Value pos2 = bcast(floatCst(builder, 2.0f, et)); + Value posInf = bcast(floatCst(builder, INFINITY, et)); + Value clampVal = bcast(floatCst(builder, 10.0546875f, et)); + + Value a = builder.create(x); + Value p = builder.create(a, pos2); + Value r = builder.create(one, p); + Value q = builder.create(neg4, r, one); + Value t = builder.create(builder.create(q, one), + neg2, a); + Value e = builder.create(builder.create(a), q, t); + q = builder.create(r, e, q); + + p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4 + Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3 + p = builder.create(p, q, c1); + Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3 + p = builder.create(p, q, c2); + Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3 + p = builder.create(p, q, c3); + Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3 + p = builder.create(p, q, c4); + Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2 + p = builder.create(p, q, c5); + Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1 + p = builder.create(p, 
q, c6); + Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1 + p = builder.create(p, q, c7); + Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2 + p = builder.create(p, q, c8); + Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1 + p = builder.create(p, q, c9); + + Value d = builder.create(pos2, a, one); + r = builder.create(one, d); + q = builder.create(p, r, r); + Value negfa = builder.create(a); + Value fmaqah = builder.create(q, negfa, onehalf); + Value psubq = builder.create(p, q); + e = builder.create(fmaqah, pos2, psubq); + r = builder.create(e, r, q); + + Value s = builder.create(a, a); + e = builder.create(builder.create(s)); + + t = builder.create(builder.create(a), a, s); + r = builder.create( + r, e, + builder.create(builder.create(r, e), t)); + + Value isNotLessThanInf = builder.create( + builder.create(arith::CmpFPredicate::OLT, a, posInf), + trueValue); + r = builder.create(isNotLessThanInf, + builder.create(x, x), r); + Value isGreaterThanClamp = + builder.create(arith::CmpFPredicate::OGT, a, clampVal); + r = builder.create(isGreaterThanClamp, zero, r); + + Value isNegative = + builder.create(arith::CmpFPredicate::OLT, x, zero); + r = builder.create( + isNegative, builder.create(pos2, r), r); + + rewriter.replaceOp(op, r); + return success(); +} //----------------------------------------------------------------------------// // Exp approximation. 
//----------------------------------------------------------------------------// @@ -1667,6 +1768,11 @@ void mlir::populatePolynomialApproximateErfPattern( patterns.add(patterns.getContext()); } +void mlir::populatePolynomialApproximateErfcPattern( + RewritePatternSet &patterns) { + patterns.add(patterns.getContext()); +} + template static void populateMathF32ExpansionPattern(RewritePatternSet &patterns, @@ -1690,6 +1796,7 @@ void mlir::populateMathF32ExpansionPatterns( populateMathF32ExpansionPattern(patterns, predicate); populateMathF32ExpansionPattern(patterns, predicate); populateMathF32ExpansionPattern(patterns, predicate); + populateMathF32ExpansionPattern(patterns, predicate); populateMathF32ExpansionPattern(patterns, predicate); populateMathF32ExpansionPattern(patterns, predicate); populateMathF32ExpansionPattern(patterns, predicate); @@ -1734,6 +1841,9 @@ void mlir::populateMathPolynomialApproximationPatterns( CosOp, SinAndCosApproximation>(patterns, predicate); populateMathPolynomialApproximationPattern( patterns, predicate); + populateMathPolynomialApproximationPattern( + patterns, predicate); populateMathPolynomialApproximationPattern( patterns, predicate); populateMathPolynomialApproximationPattern( @@ -1760,9 +1870,10 @@ void mlir::populateMathPolynomialApproximationPatterns( {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(), math::TanhOp::getOperationName(), math::LogOp::getOperationName(), math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(), - math::ErfOp::getOperationName(), math::ExpOp::getOperationName(), - math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(), - math::SinOp::getOperationName(), math::CosOp::getOperationName()}, + math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(), + math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(), + math::CbrtOp::getOperationName(), math::SinOp::getOperationName(), + math::CosOp::getOperationName()}, name); }); @@ -1774,8 +1885,9 
@@ void mlir::populateMathPolynomialApproximationPatterns( math::TanhOp::getOperationName(), math::LogOp::getOperationName(), math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(), - math::AsinOp::getOperationName(), math::AcosOp::getOperationName(), - math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(), + math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(), + math::AcosOp::getOperationName(), math::ExpOp::getOperationName(), + math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(), math::SinOp::getOperationName(), math::CosOp::getOperationName()}, name); diff --git a/mlir/test/Dialect/Math/polynomial-approximation.mlir b/mlir/test/Dialect/Math/polynomial-approximation.mlir index 81d071e6bbba3..bf7c4134af12e 100644 --- a/mlir/test/Dialect/Math/polynomial-approximation.mlir +++ b/mlir/test/Dialect/Math/polynomial-approximation.mlir @@ -81,6 +81,116 @@ func.func @erf_scalar(%arg0: f32) -> f32 { return %0 : f32 } +// CHECK-LABEL: func @erfc_scalar( +// CHECK-SAME: %[[val_arg0:.*]]: f32) -> f32 { +// CHECK-DAG: %[[c127_i32:.*]] = arith.constant 127 : i32 +// CHECK-DAG: %[[c23_i32:.*]] = arith.constant 23 : i32 +// CHECK-DAG: %[[cst:.*]] = arith.constant 1.270000e+02 : f32 +// CHECK-DAG: %[[cst_0:.*]] = arith.constant -1.270000e+02 : f32 +// CHECK-DAG: %[[cst_1:.*]] = arith.constant 8.880000e+01 : f32 +// CHECK-DAG: %[[cst_2:.*]] = arith.constant -8.780000e+01 : f32 +// CHECK-DAG: %[[cst_3:.*]] = arith.constant 0.166666657 : f32 +// CHECK-DAG: %[[cst_4:.*]] = arith.constant 0.0416657962 : f32 +// CHECK-DAG: %[[cst_5:.*]] = arith.constant 0.00833345205 : f32 +// CHECK-DAG: %[[cst_6:.*]] = arith.constant 0.00139819994 : f32 +// CHECK-DAG: %[[cst_7:.*]] = arith.constant 1.98756912E-4 : f32 +// CHECK-DAG: %[[cst_8:.*]] = arith.constant 2.12194442E-4 : f32 +// CHECK-DAG: %[[cst_9:.*]] = arith.constant -0.693359375 : f32 +// CHECK-DAG: %[[cst_10:.*]] = arith.constant 1.44269502 : f32 +// 
CHECK-DAG: %[[cst_11:.*]] = arith.constant 0.276978403 : f32 +// CHECK-DAG: %[[cst_12:.*]] = arith.constant -0.0927639827 : f32 +// CHECK-DAG: %[[cst_13:.*]] = arith.constant -0.166031361 : f32 +// CHECK-DAG: %[[cst_14:.*]] = arith.constant 0.164055392 : f32 +// CHECK-DAG: %[[cst_15:.*]] = arith.constant -0.0542046614 : f32 +// CHECK-DAG: %[[cst_16:.*]] = arith.constant -8.059920e-03 : f32 +// CHECK-DAG: %[[cst_17:.*]] = arith.constant 0.00863227434 : f32 +// CHECK-DAG: %[[cst_18:.*]] = arith.constant 0.00131355342 : f32 +// CHECK-DAG: %[[cst_19:.*]] = arith.constant -0.0012307521 : f32 +// CHECK-DAG: %[[cst_20:.*]] = arith.constant -4.01139259E-4 : f32 +// CHECK-DAG: %[[cst_true:.*]] = arith.constant true +// CHECK-DAG: %[[cst_21:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[cst_22:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[cst_23:.*]] = arith.constant 5.000000e-01 : f32 +// CHECK-DAG: %[[cst_24:.*]] = arith.constant -4.000000e+00 : f32 +// CHECK-DAG: %[[cst_25:.*]] = arith.constant -2.000000e+00 : f32 +// CHECK-DAG: %[[cst_26:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[cst_27:.*]] = arith.constant 0x7F800000 : f32 +// CHECK-DAG: %[[cst_28:.*]] = arith.constant 10.0546875 : f32 +// CHECK: %[[val_2:.*]] = math.absf %[[val_arg0]] : f32 +// CHECK-NEXT: %[[val_3:.*]] = arith.addf %[[val_2]], %[[cst_26]] : f32 +// CHECK-NEXT: %[[val_4:.*]] = arith.divf %[[cst_22]], %[[val_3]] : f32 +// CHECK-NEXT: %[[val_5:.*]] = math.fma %[[cst_24]], %[[val_4]], %[[cst_22]] : f32 +// CHECK-NEXT: %[[val_6:.*]] = arith.addf %[[val_5]], %[[cst_22]] : f32 +// CHECK-NEXT: %[[val_7:.*]] = math.fma %[[val_6]], %[[cst_25]], %[[val_2]] : f32 +// CHECK-NEXT: %[[val_8:.*]] = arith.negf %[[val_2]] : f32 +// CHECK-NEXT: %[[val_9:.*]] = math.fma %[[val_8]], %[[val_5]], %[[val_7]] : f32 +// CHECK-NEXT: %[[val_10:.*]] = math.fma %[[val_4]], %[[val_9]], %[[val_5]] : f32 +// CHECK-NEXT: %[[val_11:.*]] = math.fma %[[cst_20]], %[[val_10]], %[[cst_19]] : f32 +// 
CHECK-NEXT: %[[val_12:.*]] = math.fma %[[val_11]], %[[val_10]], %[[cst_18]] : f32 +// CHECK-NEXT: %[[val_13:.*]] = math.fma %[[val_12]], %[[val_10]], %[[cst_17]] : f32 +// CHECK-NEXT: %[[val_14:.*]] = math.fma %[[val_13]], %[[val_10]], %[[cst_16]] : f32 +// CHECK-NEXT: %[[val_15:.*]] = math.fma %[[val_14]], %[[val_10]], %[[cst_15]] : f32 +// CHECK-NEXT: %[[val_16:.*]] = math.fma %[[val_15]], %[[val_10]], %[[cst_14]] : f32 +// CHECK-NEXT: %[[val_17:.*]] = math.fma %[[val_16]], %[[val_10]], %[[cst_13]] : f32 +// CHECK-NEXT: %[[val_18:.*]] = math.fma %[[val_17]], %[[val_10]], %[[cst_12]] : f32 +// CHECK-NEXT: %[[val_19:.*]] = math.fma %[[val_18]], %[[val_10]], %[[cst_11]] : f32 +// CHECK-NEXT: %[[val_20:.*]] = math.fma %[[cst_26]], %[[val_2]], %[[cst_22]] : f32 +// CHECK-NEXT: %[[val_21:.*]] = arith.divf %[[cst_22]], %[[val_20]] : f32 +// CHECK-NEXT: %[[val_22:.*]] = math.fma %[[val_19]], %[[val_21]], %[[val_21]] : f32 +// CHECK-NEXT: %[[val_23:.*]] = arith.negf %[[val_2]] : f32 +// CHECK-NEXT: %[[val_24:.*]] = math.fma %[[val_22]], %[[val_23]], %[[cst_23]] : f32 +// CHECK-NEXT: %[[val_25:.*]] = arith.subf %[[val_19]], %[[val_22]] : f32 +// CHECK-NEXT: %[[val_26:.*]] = math.fma %[[val_24]], %[[cst_26]], %[[val_25]] : f32 +// CHECK-NEXT: %[[val_27:.*]] = math.fma %[[val_26]], %[[val_21]], %[[val_22]] : f32 +// CHECK-NEXT: %[[val_28:.*]] = arith.mulf %[[val_2]], %[[val_2]] : f32 +// CHECK-NEXT: %[[val_29:.*]] = arith.negf %[[val_28]] : f32 +// CHECK-NEXT: %[[val_30:.*]] = arith.cmpf uge, %[[val_29]], %[[cst_2]] : f32 +// CHECK-NEXT: %[[val_31:.*]] = arith.select %[[val_30]], %[[val_29]], %[[cst_2]] : f32 +// CHECK-NEXT: %[[val_32:.*]] = arith.cmpf ule, %[[val_31]], %[[cst_1]] : f32 +// CHECK-NEXT: %[[val_33:.*]] = arith.select %[[val_32]], %[[val_31]], %[[cst_1]] : f32 +// CHECK-NEXT: %[[val_34:.*]] = math.fma %[[val_33]], %[[cst_10]], %[[cst_23]] : f32 +// CHECK-NEXT: %[[val_35:.*]] = math.floor %[[val_34]] : f32 +// CHECK-NEXT: %[[val_36:.*]] = arith.cmpf uge, 
%[[val_35]], %[[cst_0]] : f32 +// CHECK-NEXT: %[[val_37:.*]] = arith.select %[[val_36]], %[[val_35]], %[[cst_0]] : f32 +// CHECK-NEXT: %[[val_38:.*]] = arith.cmpf ule, %[[val_37]], %[[cst]] : f32 +// CHECK-NEXT: %[[val_39:.*]] = arith.select %[[val_38]], %[[val_37]], %[[cst]] : f32 +// CHECK-NEXT: %[[val_40:.*]] = math.fma %[[cst_9]], %[[val_39]], %[[val_33]] : f32 +// CHECK-NEXT: %[[val_41:.*]] = math.fma %[[cst_8]], %[[val_39]], %[[val_40]] : f32 +// CHECK-NEXT: %[[val_42:.*]] = math.fma %[[val_41]], %[[cst_7]], %[[cst_6]] : f32 +// CHECK-NEXT: %[[val_43:.*]] = math.fma %[[val_42]], %[[val_41]], %[[cst_5]] : f32 +// CHECK-NEXT: %[[val_44:.*]] = math.fma %[[val_43]], %[[val_41]], %[[cst_4]] : f32 +// CHECK-NEXT: %[[val_45:.*]] = math.fma %[[val_44]], %[[val_41]], %[[cst_3]] : f32 +// CHECK-NEXT: %[[val_46:.*]] = math.fma %[[val_45]], %[[val_41]], %[[cst_23]] : f32 +// CHECK-NEXT: %[[val_47:.*]] = arith.mulf %[[val_41]], %[[val_41]] : f32 +// CHECK-NEXT: %[[val_48:.*]] = math.fma %[[val_46]], %[[val_47]], %[[val_41]] : f32 +// CHECK-NEXT: %[[val_49:.*]] = arith.addf %[[val_48]], %[[cst_22]] : f32 +// CHECK-NEXT: %[[val_50:.*]] = arith.fptosi %[[val_39]] : f32 to i32 +// CHECK-NEXT: %[[val_51:.*]] = arith.addi %[[val_50]], %[[c127_i32]] : i32 +// CHECK-NEXT: %[[val_52:.*]] = arith.shli %[[val_51]], %[[c23_i32]] : i32 +// CHECK-NEXT: %[[val_53:.*]] = arith.bitcast %[[val_52]] : i32 to f32 +// CHECK-NEXT: %[[val_54:.*]] = arith.mulf %[[val_49]], %[[val_53]] : f32 +// CHECK-NEXT: %[[val_55:.*]] = arith.negf %[[val_2]] : f32 +// CHECK-NEXT: %[[val_56:.*]] = math.fma %[[val_55]], %[[val_2]], %[[val_28]] : f32 +// CHECK-NEXT: %[[val_57:.*]] = arith.mulf %[[val_27]], %[[val_54]] : f32 +// CHECK-NEXT: %[[val_58:.*]] = arith.mulf %[[val_57]], %[[val_56]] : f32 +// CHECK-NEXT: %[[val_59:.*]] = math.fma %[[val_27]], %[[val_54]], %[[val_58]] : f32 +// CHECK-NEXT: %[[val_60:.*]] = arith.cmpf olt, %[[val_2]], %[[cst_27]] : f32 +// CHECK-NEXT: %[[val_61:.*]] = arith.xori 
%[[val_60]], %[[cst_true]] : i1 +// CHECK-NEXT: %[[val_62:.*]] = arith.addf %[[val_arg0]], %[[val_arg0]] : f32 +// CHECK-NEXT: %[[val_63:.*]] = arith.select %[[val_61]], %[[val_62]], %[[val_59]] : f32 +// CHECK-NEXT: %[[val_64:.*]] = arith.cmpf ogt, %[[val_2]], %[[cst_28]] : f32 +// CHECK-NEXT: %[[val_65:.*]] = arith.select %[[val_64]], %[[cst_21]], %[[val_63]] : f32 +// CHECK-NEXT: %[[val_66:.*]] = arith.cmpf olt, %[[val_arg0]], %[[cst_21]] : f32 +// CHECK-NEXT: %[[val_67:.*]] = arith.subf %[[cst_26]], %[[val_65]] : f32 +// CHECK-NEXT: %[[val_68:.*]] = arith.select %[[val_66]], %[[val_67]], %[[val_65]] : f32 +// CHECK-NEXT: return %[[val_68]] : f32 +// CHECK-NEXT: } + +func.func @erfc_scalar(%arg0: f32) -> f32 { + %0 = math.erfc %arg0 : f32 + return %0 : f32 +} + // CHECK-LABEL: func @erf_vector( // CHECK-SAME: %[[arg0:.*]]: vector<8xf32>) -> vector<8xf32> { // CHECK: %[[zero:.*]] = arith.constant dense<0.000000e+00> : vector<8xf32> diff --git a/mlir/test/mlir-runner/math-polynomial-approx.mlir b/mlir/test/mlir-runner/math-polynomial-approx.mlir index 148ef25cead62..6ed03916f1e15 100644 --- a/mlir/test/mlir-runner/math-polynomial-approx.mlir +++ b/mlir/test/mlir-runner/math-polynomial-approx.mlir @@ -273,6 +273,77 @@ func.func @erf() { return } +// -------------------------------------------------------------------------- // +// Erfc. 
+// -------------------------------------------------------------------------- // +func.func @erfc_f32(%a : f32) { + %r = math.erfc %a : f32 + vector.print %r : f32 + return +} + +func.func @erfc_4xf32(%a : vector<4xf32>) { + %r = math.erfc %a : vector<4xf32> + vector.print %r : vector<4xf32> + return +} + +func.func @erfc() { + // CHECK: 1.00027 + %val1 = arith.constant -2.431864e-4 : f32 + call @erfc_f32(%val1) : (f32) -> () + + // CHECK: 0.257905 + %val2 = arith.constant 0.79999 : f32 + call @erfc_f32(%val2) : (f32) -> () + + // CHECK: 0.257899 + %val3 = arith.constant 0.8 : f32 + call @erfc_f32(%val3) : (f32) -> () + + // CHECK: 0.00467794 + %val4 = arith.constant 1.99999 : f32 + call @erfc_f32(%val4) : (f32) -> () + + // CHECK: 0.00467774 + %val5 = arith.constant 2.0 : f32 + call @erfc_f32(%val5) : (f32) -> () + + // CHECK: 1.13736e-07 + %val6 = arith.constant 3.74999 : f32 + call @erfc_f32(%val6) : (f32) -> () + + // CHECK: 1.13727e-07 + %val7 = arith.constant 3.75 : f32 + call @erfc_f32(%val7) : (f32) -> () + + // CHECK: 2 + %negativeInf = arith.constant 0xff800000 : f32 + call @erfc_f32(%negativeInf) : (f32) -> () + + // CHECK: 2, 2, 1.91376, 1.73145 + %vecVals1 = arith.constant dense<[-3.4028235e+38, -4.54318, -1.2130899, -7.8234202e-01]> : vector<4xf32> + call @erfc_4xf32(%vecVals1) : (vector<4xf32>) -> () + + // CHECK: 1, 1, 1, 0.878681 + %vecVals2 = arith.constant dense<[-1.1754944e-38, 0.0, 1.1754944e-38, 1.0793410e-01]> : vector<4xf32> + call @erfc_4xf32(%vecVals2) : (vector<4xf32>) -> () + + // CHECK: 0.0805235, 0.000931045, 6.40418e-08, 0 + %vecVals3 = arith.constant dense<[1.23578, 2.34093, 3.82342, 3.4028235e+38]> : vector<4xf32> + call @erfc_4xf32(%vecVals3) : (vector<4xf32>) -> () + + // CHECK: 0 + %inf = arith.constant 0x7f800000 : f32 + call @erfc_f32(%inf) : (f32) -> () + + // CHECK: nan + %nan = arith.constant 0x7fc00000 : f32 + call @erfc_f32(%nan) : (f32) -> () + + return +} + // 
-------------------------------------------------------------------------- // // Exp. // -------------------------------------------------------------------------- // @@ -772,6 +843,7 @@ func.func @main() { call @log2(): () -> () call @log1p(): () -> () call @erf(): () -> () + call @erfc(): () -> () call @exp(): () -> () call @expm1(): () -> () call @sin(): () -> () diff --git a/mlir/utils/vim/syntax/mlir.vim b/mlir/utils/vim/syntax/mlir.vim index 7989032eada88..070d81658ca3d 100644 --- a/mlir/utils/vim/syntax/mlir.vim +++ b/mlir/utils/vim/syntax/mlir.vim @@ -44,6 +44,7 @@ syn keyword mlirOps view " Math ops. syn match mlirOps /\/ +syn match mlirOps /\/ " Affine ops. syn match mlirOps /\/