@@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Helper functions to create constants.
//----------------------------------------------------------------------------//

+static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
+  return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
+}
+
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
                      Type elementType) {
  assert((elementType.isF16() || elementType.isF32()) &&
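For context (not part of the patch): like the existing floatCst/f32Cst helpers, the new boolCst is meant to be combined with the file's broadcast helper when the operand is a vector, roughly as in the sketch below. The wrapper name is made up; boolCst, broadcast, and VectorShape are the helpers already present in this file.

// Hypothetical illustration of how boolCst is consumed later in the patch:
// build a scalar `arith.constant true` and splat it to the operand's shape.
static Value splatTrueCst(ImplicitLocOpBuilder &builder,
                          std::optional<VectorShape> shape) {
  return broadcast(builder, boolCst(builder, true), shape);
}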
@@ -1118,6 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
  return success();
}

+// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a 9 degree
+// polynomial. This approximation is based on the following stackoverflow post:
+// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
+// The stackoverflow post is in turn based on:
+// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
+// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
+// No. 153, January 1981, pp. 249-253.
+//
+// Maximum error: 2.65 ulps
+LogicalResult
+ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
+                                             PatternRewriter &rewriter) const {
+  Value x = op.getOperand();
+  Type et = getElementTypeOrSelf(x);
+
+  if (!et.isF32())
+    return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
+  std::optional<VectorShape> shape = vectorShape(x);
+
+  ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
+  auto bcast = [&](Value value) -> Value {
+    return broadcast(builder, value, shape);
+  };
+
+  Value trueValue = bcast(boolCst(builder, true));
+  Value zero = bcast(floatCst(builder, 0.0f, et));
+  Value one = bcast(floatCst(builder, 1.0f, et));
+  Value onehalf = bcast(floatCst(builder, 0.5f, et));
+  Value neg4 = bcast(floatCst(builder, -4.0f, et));
+  Value neg2 = bcast(floatCst(builder, -2.0f, et));
+  Value pos2 = bcast(floatCst(builder, 2.0f, et));
+  Value posInf = bcast(floatCst(builder, INFINITY, et));
+  Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
+
+  Value a = builder.create<math::AbsFOp>(x);
+  Value p = builder.create<arith::AddFOp>(a, pos2);
+  Value r = builder.create<arith::DivFOp>(one, p);
+  Value q = builder.create<math::FmaOp>(neg4, r, one);
+  Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
+                                        neg2, a);
+  Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
+  q = builder.create<math::FmaOp>(r, e, q);
+
+  p = bcast(floatCst(builder, -0x1.a4a000p-12f, et));        // -4.01139259e-4
+  Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
+  p = builder.create<math::FmaOp>(p, q, c1);
+  Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et));  // 1.31355342e-3
+  p = builder.create<math::FmaOp>(p, q, c2);
+  Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et));  // 8.63227434e-3
+  p = builder.create<math::FmaOp>(p, q, c3);
+  Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
+  p = builder.create<math::FmaOp>(p, q, c4);
+  Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
+  p = builder.create<math::FmaOp>(p, q, c5);
+  Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et));  // 1.64055392e-1
+  p = builder.create<math::FmaOp>(p, q, c6);
+  Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
+  p = builder.create<math::FmaOp>(p, q, c7);
+  Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
+  p = builder.create<math::FmaOp>(p, q, c8);
+  Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et));  // 2.76978403e-1
+  p = builder.create<math::FmaOp>(p, q, c9);
+
+  Value d = builder.create<math::FmaOp>(pos2, a, one);
+  r = builder.create<arith::DivFOp>(one, d);
+  q = builder.create<math::FmaOp>(p, r, r);
+  Value negfa = builder.create<arith::NegFOp>(a);
+  Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
+  Value psubq = builder.create<arith::SubFOp>(p, q);
+  e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
+  r = builder.create<math::FmaOp>(e, r, q);
+
+  Value s = builder.create<arith::MulFOp>(a, a);
+  e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
+
+  t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
+  r = builder.create<math::FmaOp>(
+      r, e,
+      builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
+
+  Value isNotLessThanInf = builder.create<arith::XOrIOp>(
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
+      trueValue);
+  r = builder.create<arith::SelectOp>(isNotLessThanInf,
+                                      builder.create<arith::AddFOp>(x, x), r);
+  Value isGreaterThanClamp =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
+  r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
+
+  Value isNegative =
+      builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
+  r = builder.create<arith::SelectOp>(
+      isNegative, builder.create<arith::SubFOp>(pos2, r), r);
+
+  rewriter.replaceOp(op, r);
+  return success();
+}
//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//
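To make the builder sequence above easier to follow, here is a scalar sketch of the same arithmetic in plain C++ (illustrative only, not part of the patch; the function name is made up, and it assumes C++17 hex float literals and <cmath>). Roughly speaking, the polynomial p approximates (1 + 2a) * exp(a^2) * erfc(a) - 1 in the transformed variable q = (a - 2) / (a + 2), so the result is recovered as (1 + p(q)) * exp(-a^2) / (1 + 2a), followed by selects for NaN/Inf inputs, the underflow clamp at 10.0546875, and the reflection erfc(-x) = 2 - erfc(x).

#include <cmath>
#include <limits>

// Scalar mirror of the FMA sequence built by the pattern above (untested sketch).
float erfcApproxSketch(float x) {
  float a = std::fabs(x);

  // q = (a - 2) / (a + 2), with an FMA-based correction of the division.
  float p = a + 2.0f;
  float r = 1.0f / p;
  float q = std::fma(-4.0f, r, 1.0f);
  float t = std::fma(q + 1.0f, -2.0f, a);
  float e = std::fma(-a, q, t);
  q = std::fma(r, e, q);

  // Degree-9 polynomial in q (Horner form), same coefficients as the pattern.
  p = -0x1.a4a000p-12f;                 // -4.01139259e-4
  p = std::fma(p, q, -0x1.42a260p-10f); // -1.23075210e-3
  p = std::fma(p, q, 0x1.585714p-10f);  //  1.31355342e-3
  p = std::fma(p, q, 0x1.1adcc4p-07f);  //  8.63227434e-3
  p = std::fma(p, q, -0x1.081b82p-07f); // -8.05991981e-3
  p = std::fma(p, q, -0x1.bc0b6ap-05f); // -5.42046614e-2
  p = std::fma(p, q, 0x1.4ffc46p-03f);  //  1.64055392e-1
  p = std::fma(p, q, -0x1.540840p-03f); // -1.66031361e-1
  p = std::fma(p, q, -0x1.7bf616p-04f); // -9.27639827e-2
  p = std::fma(p, q, 0x1.1ba03ap-02f);  //  2.76978403e-1

  // Divide by (1 + 2a) using an error-compensated reciprocal.
  float d = std::fma(2.0f, a, 1.0f);
  r = 1.0f / d;
  q = std::fma(p, r, r);
  e = std::fma(std::fma(q, -a, 0.5f), 2.0f, p - q);
  r = std::fma(e, r, q);

  // Multiply by exp(-a^2), compensating the rounding error of a * a.
  float s = a * a;
  e = std::exp(-s);
  t = std::fma(-a, a, s);
  r = std::fma(r, e, r * e * t);

  // Special cases: NaN/Inf propagation, underflow clamp, erfc(-x) = 2 - erfc(x).
  if (!(a < std::numeric_limits<float>::infinity()))
    r = x + x;
  if (a > 10.0546875f)
    r = 0.0f;
  if (x < 0.0f)
    r = 2.0f - r;
  return r;
}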
@@ -1667,6 +1768,11 @@ void mlir::populatePolynomialApproximateErfPattern(
  patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}

+void mlir::populatePolynomialApproximateErfcPattern(
+    RewritePatternSet &patterns) {
+  patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
+}
+
template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
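A downstream pass can opt into just this approximation the same way the existing erf populate helper is used. The following is a hypothetical consumer sketch (not part of the patch): the wrapper function is made up, and it assumes the usual mlir::applyPatternsAndFoldGreedily greedy rewrite driver.

// Hypothetical consumer (sketch): lower only math.erfc to the polynomial
// approximation, leaving all other math ops untouched.
static void approximateErfcOnly(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  populatePolynomialApproximateErfcPattern(patterns);
  // Best-effort greedy application; convergence failure is ignored here.
  (void)applyPatternsAndFoldGreedily(root, std::move(patterns));
}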
@@ -1690,6 +1796,7 @@ void mlir::populateMathF32ExpansionPatterns(
  populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
  populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
  populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
+  populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
  populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
  populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
  populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
@@ -1734,6 +1841,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
      CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
  populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
      patterns, predicate);
+  populateMathPolynomialApproximationPattern<ErfcOp,
+                                             ErfcPolynomialApproximation>(
+      patterns, predicate);
  populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
      patterns, predicate);
  populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
@@ -1760,9 +1870,10 @@ void mlir::populateMathPolynomialApproximationPatterns(
        {math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
         math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
         math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
-         math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
-         math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
-         math::SinOp::getOperationName(), math::CosOp::getOperationName()},
+         math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
+         math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
+         math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
+         math::CosOp::getOperationName()},
        name);
  });

@@ -1774,8 +1885,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
         math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
         math::Log2Op::getOperationName(),
         math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
-         math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
-         math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
+         math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
+         math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
+         math::ExpM1Op::getOperationName(),
         math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
         math::CosOp::getOperationName()},
        name);