@@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
173173// Helper functions to create constants.
174174// ----------------------------------------------------------------------------//
175175
176+ static Value boolCst (ImplicitLocOpBuilder &builder, bool value) {
177+ return builder.create <arith::ConstantOp>(builder.getBoolAttr (value));
178+ }
179+
176180static Value floatCst (ImplicitLocOpBuilder &builder, float value,
177181 Type elementType) {
178182 assert ((elementType.isF16 () || elementType.isF32 ()) &&
@@ -1118,12 +1122,102 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
11181122 return success ();
11191123}
11201124
1125+ // Approximates erfc(x) with
1126+ LogicalResult
1127+ ErfcPolynomialApproximation::matchAndRewrite (math::ErfcOp op,
1128+ PatternRewriter &rewriter) const {
1129+ Value x = op.getOperand ();
1130+ Type et = getElementTypeOrSelf (x);
1131+
1132+ if (!et.isF32 ())
1133+ return rewriter.notifyMatchFailure (op, " only f32 type is supported." );
1134+ std::optional<VectorShape> shape = vectorShape (x);
1135+
1136+ ImplicitLocOpBuilder builder (op->getLoc (), rewriter);
1137+ auto bcast = [&](Value value) -> Value {
1138+ return broadcast (builder, value, shape);
1139+ };
1140+
1141+ Value trueValue = bcast (boolCst (builder, true ));
1142+ Value zero = bcast (floatCst (builder, 0 .0f , et));
1143+ Value one = bcast (floatCst (builder, 1 .0f , et));
1144+ Value onehalf = bcast (floatCst (builder, 0 .5f , et));
1145+ Value neg4 = bcast (floatCst (builder, -4 .0f , et));
1146+ Value neg2 = bcast (floatCst (builder, -2 .0f , et));
1147+ Value pos2 = bcast (floatCst (builder, 2 .0f , et));
1148+ Value posInf = bcast (f32FromBits (builder, 0x7f800000u ));
1149+ Value clampVal = bcast (floatCst (builder, 10 .0546875f , et));
1150+
1151+ // Get abs(x)
1152+ Value isNegativeArg =
1153+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1154+ Value negArg = builder.create <arith::NegFOp>(x);
1155+ Value a = builder.create <arith::SelectOp>(isNegativeArg, negArg, x);
1156+ Value p = builder.create <arith::AddFOp>(a, pos2);
1157+ Value r = builder.create <arith::DivFOp>(one, p);
1158+ Value q = builder.create <math::FmaOp>(neg4, r, one);
1159+ Value t = builder.create <math::FmaOp>(builder.create <arith::AddFOp>(q, one),
1160+ neg2, a);
1161+ Value e = builder.create <math::FmaOp>(builder.create <arith::NegFOp>(a), q, t);
1162+ q = builder.create <math::FmaOp>(r, e, q);
1163+
1164+ p = bcast (floatCst (builder, -0x1 .a4a000p -12f , et)); // -4.01139259e-4
1165+ Value c1 = bcast (floatCst (builder, -0x1 .42a260p-10f , et)); // -1.23075210e-3
1166+ p = builder.create <math::FmaOp>(p, q, c1);
1167+ Value c2 = bcast (floatCst (builder, 0x1 .585714p-10f , et)); // 1.31355342e-3
1168+ p = builder.create <math::FmaOp>(p, q, c2);
1169+ Value c3 = bcast (floatCst (builder, 0x1 .1adcc4p-07f , et)); // 8.63227434e-3
1170+ p = builder.create <math::FmaOp>(p, q, c3);
1171+ Value c4 = bcast (floatCst (builder, -0x1 .081b82p-07f , et)); // -8.05991981e-3
1172+ p = builder.create <math::FmaOp>(p, q, c4);
1173+ Value c5 = bcast (floatCst (builder, -0x1 .bc0b6ap -05f , et)); // -5.42046614e-2
1174+ p = builder.create <math::FmaOp>(p, q, c5);
1175+ Value c6 = bcast (floatCst (builder, 0x1 .4ffc46p-03f , et)); // 1.64055392e-1
1176+ p = builder.create <math::FmaOp>(p, q, c6);
1177+ Value c7 = bcast (floatCst (builder, -0x1 .540840p-03f , et)); // -1.66031361e-1
1178+ p = builder.create <math::FmaOp>(p, q, c7);
1179+ Value c8 = bcast (floatCst (builder, -0x1 .7bf616p-04f , et)); // -9.27639827e-2
1180+ p = builder.create <math::FmaOp>(p, q, c8);
1181+ Value c9 = bcast (floatCst (builder, 0x1 .1ba03ap-02f , et)); // 2.76978403e-1
1182+ p = builder.create <math::FmaOp>(p, q, c9);
1183+
1184+ Value d = builder.create <math::FmaOp>(pos2, a, one);
1185+ r = builder.create <arith::DivFOp>(one, d);
1186+ q = builder.create <math::FmaOp>(p, r, r);
1187+ e = builder.create <math::FmaOp>(
1188+ builder.create <math::FmaOp>(q, builder.create <arith::NegFOp>(a), onehalf),
1189+ pos2, builder.create <arith::SubFOp>(p, q));
1190+ r = builder.create <math::FmaOp>(e, r, q);
1191+
1192+ Value s = builder.create <arith::MulFOp>(a, a);
1193+ e = builder.create <math::ExpOp>(builder.create <arith::NegFOp>(s));
1194+
1195+ t = builder.create <math::FmaOp>(builder.create <arith::NegFOp>(a), a, s);
1196+ r = builder.create <math::FmaOp>(
1197+ r, e,
1198+ builder.create <arith::MulFOp>(builder.create <arith::MulFOp>(r, e), t));
1199+
1200+ Value isNotLessThanInf = builder.create <arith::XOrIOp>(
1201+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
1202+ trueValue);
1203+ r = builder.create <arith::SelectOp>(isNotLessThanInf,
1204+ builder.create <arith::AddFOp>(x, x), r);
1205+ Value isGreaterThanClamp =
1206+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
1207+ r = builder.create <arith::SelectOp>(isGreaterThanClamp, zero, r);
1208+
1209+ Value isNegative =
1210+ builder.create <arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1211+ r = builder.create <arith::SelectOp>(
1212+ isNegative, builder.create <arith::SubFOp>(pos2, r), r);
1213+
1214+ rewriter.replaceOp (op, r);
1215+ return success ();
1216+ }
11211217// ----------------------------------------------------------------------------//
11221218// Exp approximation.
11231219// ----------------------------------------------------------------------------//
1124-
11251220namespace {
1126-
11271221Value clampWithNormals (ImplicitLocOpBuilder &builder,
11281222 const std::optional<VectorShape> shape, Value value,
11291223 float lowerBound, float upperBound) {
@@ -1667,6 +1761,11 @@ void mlir::populatePolynomialApproximateErfPattern(
16671761 patterns.add <ErfPolynomialApproximation>(patterns.getContext ());
16681762}
16691763
1764+ void mlir::populatePolynomialApproximateErfcPattern (
1765+ RewritePatternSet &patterns) {
1766+ patterns.add <ErfcPolynomialApproximation>(patterns.getContext ());
1767+
1768+
16701769template <typename OpType>
16711770static void
16721771populateMathF32ExpansionPattern (RewritePatternSet &patterns,
@@ -1690,6 +1789,7 @@ void mlir::populateMathF32ExpansionPatterns(
16901789 populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
16911790 populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
16921791 populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
1792+ populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
16931793 populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
16941794 populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
16951795 populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
@@ -1734,6 +1834,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
17341834 CosOp, SinAndCosApproximation<false , math::CosOp>>(patterns, predicate);
17351835 populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
17361836 patterns, predicate);
1837+ populateMathPolynomialApproximationPattern<ErfcOp,
1838+ ErfcPolynomialApproximation>(
1839+ patterns, predicate);
17371840 populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
17381841 patterns, predicate);
17391842 populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
@@ -1753,16 +1856,17 @@ void mlir::populateMathPolynomialApproximationPatterns(
17531856}
17541857
17551858void mlir::populateMathPolynomialApproximationPatterns (
1756- RewritePatternSet &patterns,
1859+ RewritePatternSet & patterns,
17571860 const MathPolynomialApproximationOptions &options) {
17581861 mlir::populateMathF32ExpansionPatterns (patterns, [](StringRef name) -> bool {
17591862 return llvm::is_contained (
17601863 {math::AtanOp::getOperationName (), math::Atan2Op::getOperationName (),
17611864 math::TanhOp::getOperationName (), math::LogOp::getOperationName (),
17621865 math::Log2Op::getOperationName (), math::Log1pOp::getOperationName (),
1763- math::ErfOp::getOperationName (), math::ExpOp::getOperationName (),
1764- math::ExpM1Op::getOperationName (), math::CbrtOp::getOperationName (),
1765- math::SinOp::getOperationName (), math::CosOp::getOperationName ()},
1866+ math::ErfOp::getOperationName (), math::ErfcOp::getOperationName (),
1867+ math::ExpOp::getOperationName (), math::ExpM1Op::getOperationName (),
1868+ math::CbrtOp::getOperationName (), math::SinOp::getOperationName (),
1869+ math::CosOp::getOperationName ()},
17661870 name);
17671871 });
17681872
@@ -1774,8 +1878,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
17741878 math::TanhOp::getOperationName (), math::LogOp::getOperationName (),
17751879 math::Log2Op::getOperationName (),
17761880 math::Log1pOp::getOperationName (), math::ErfOp::getOperationName (),
1777- math::AsinOp::getOperationName (), math::AcosOp::getOperationName (),
1778- math::ExpOp::getOperationName (), math::ExpM1Op::getOperationName (),
1881+ math::ErcfOp::getOperationName (), math::AsinOp::getOperationName (),
1882+ math::AcosOp::getOperationName (), math::ExpOp::getOperationName (),
1883+ math::ExpM1Op::getOperationName (),
17791884 math::CbrtOp::getOperationName (), math::SinOp::getOperationName (),
17801885 math::CosOp::getOperationName ()},
17811886 name);
0 commit comments