Skip to content

Commit 8806311

Browse files
authored
[MLIR][Math] Add erfc to math dialect (llvm#126439)
This patch adds the erfc op to the math dialect and lowers the math.erfc op to libm calls. It also adds an f32 polynomial approximation for the function, based on https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf, which is in turn based on M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36, No. 153, January 1981, pp. 249-253. The approximation has a tested ULP error of less than 3, and the MLIR test values were verified against the C implementation.
1 parent e1a393e commit 8806311

File tree

9 files changed

+353
-5
lines changed

9 files changed

+353
-5
lines changed

mlir/include/mlir/Dialect/Math/IR/MathOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,31 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
560560
let hasFolder = 1;
561561
}
562562

563+
//===----------------------------------------------------------------------===//
564+
// ErfcOp
565+
//===----------------------------------------------------------------------===//
566+
567+
def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> {
568+
let summary = "complementary error function of the specified value";
569+
let description = [{
570+
571+
The `erfc` operation computes the complementary error function, defined as
572+
1-erf(x). This function is part of libm and is needed for accuracy, since
573+
simply calculating 1-erf(x) when erf(x) is close to 1 (i.e., for large positive x) will give inaccurate results.
574+
It takes one operand of floating point type (i.e., scalar,
575+
tensor or vector) and returns one result of the same type. It has no
576+
standard attributes.
577+
578+
Example:
579+
580+
```mlir
581+
// Scalar complementary error function value.
582+
%a = math.erfc %b : f64
583+
```
584+
}];
585+
let hasFolder = 1;
586+
}
587+
563588

564589
//===----------------------------------------------------------------------===//
565590
// ExpOp

mlir/include/mlir/Dialect/Math/Transforms/Approximation.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,14 @@ struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
2323
PatternRewriter &rewriter) const final;
2424
};
2525

26+
struct ErfcPolynomialApproximation : public OpRewritePattern<math::ErfcOp> {
27+
public:
28+
using OpRewritePattern::OpRewritePattern;
29+
30+
LogicalResult matchAndRewrite(math::ErfcOp op,
31+
PatternRewriter &rewriter) const final;
32+
};
33+
2634
} // namespace math
2735
} // namespace mlir
2836

mlir/include/mlir/Dialect/Math/Transforms/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions {
4747

4848
void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
4949
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
50+
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);
5051

5152
// Adds patterns to convert to f32 around math functions for which `predicate`
5253
// returns true.

mlir/lib/Conversion/MathToLibm/MathToLibm.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
181181
populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
182182
populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
183183
populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
184+
populatePatternsForOp<math::ErfcOp>(patterns, benefit, ctx, "erfcf", "erfc");
184185
populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
185186
populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
186187
populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",

mlir/lib/Dialect/Math/IR/MathOps.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
332332
});
333333
}
334334

335+
//===----------------------------------------------------------------------===//
336+
// ErfcOp folder
337+
//===----------------------------------------------------------------------===//
338+
339+
OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
340+
return constFoldUnaryOpConditional<FloatAttr>(
341+
adaptor.getOperands(), [](const APFloat &a) -> std::optional<APFloat> {
342+
switch (APFloat::SemanticsToEnum(a.getSemantics())) {
343+
case APFloat::Semantics::S_IEEEdouble:
344+
return APFloat(erfc(a.convertToDouble()));
345+
case APFloat::Semantics::S_IEEEsingle:
346+
return APFloat(erfcf(a.convertToFloat()));
347+
default:
348+
return {};
349+
}
350+
});
351+
}
352+
335353
//===----------------------------------------------------------------------===//
336354
// IPowIOp folder
337355
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp

Lines changed: 117 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
173173
// Helper functions to create constants.
174174
//----------------------------------------------------------------------------//
175175

176+
static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
177+
return builder.create<arith::ConstantOp>(builder.getBoolAttr(value));
178+
}
179+
176180
static Value floatCst(ImplicitLocOpBuilder &builder, float value,
177181
Type elementType) {
178182
assert((elementType.isF16() || elementType.isF32()) &&
@@ -1118,6 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
11181122
return success();
11191123
}
11201124

1125+
// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a 9 degree
1126+
// polynomial. This approximation is based on the following stackoverflow post:
1127+
// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
1128+
// The stackoverflow post is in turn based on:
1129+
// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
1130+
// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
1131+
// No. 153, January 1981, pp. 249-253.
1132+
//
1133+
// Maximum error: 2.65 ulps
1134+
LogicalResult
1135+
ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
1136+
PatternRewriter &rewriter) const {
1137+
Value x = op.getOperand();
1138+
Type et = getElementTypeOrSelf(x);
1139+
1140+
if (!et.isF32())
1141+
return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
1142+
std::optional<VectorShape> shape = vectorShape(x);
1143+
1144+
ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
1145+
auto bcast = [&](Value value) -> Value {
1146+
return broadcast(builder, value, shape);
1147+
};
1148+
1149+
Value trueValue = bcast(boolCst(builder, true));
1150+
Value zero = bcast(floatCst(builder, 0.0f, et));
1151+
Value one = bcast(floatCst(builder, 1.0f, et));
1152+
Value onehalf = bcast(floatCst(builder, 0.5f, et));
1153+
Value neg4 = bcast(floatCst(builder, -4.0f, et));
1154+
Value neg2 = bcast(floatCst(builder, -2.0f, et));
1155+
Value pos2 = bcast(floatCst(builder, 2.0f, et));
1156+
Value posInf = bcast(floatCst(builder, INFINITY, et));
1157+
Value clampVal = bcast(floatCst(builder, 10.0546875f, et));
1158+
1159+
Value a = builder.create<math::AbsFOp>(x);
1160+
Value p = builder.create<arith::AddFOp>(a, pos2);
1161+
Value r = builder.create<arith::DivFOp>(one, p);
1162+
Value q = builder.create<math::FmaOp>(neg4, r, one);
1163+
Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
1164+
neg2, a);
1165+
Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
1166+
q = builder.create<math::FmaOp>(r, e, q);
1167+
1168+
p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
1169+
Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
1170+
p = builder.create<math::FmaOp>(p, q, c1);
1171+
Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
1172+
p = builder.create<math::FmaOp>(p, q, c2);
1173+
Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
1174+
p = builder.create<math::FmaOp>(p, q, c3);
1175+
Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
1176+
p = builder.create<math::FmaOp>(p, q, c4);
1177+
Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
1178+
p = builder.create<math::FmaOp>(p, q, c5);
1179+
Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
1180+
p = builder.create<math::FmaOp>(p, q, c6);
1181+
Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
1182+
p = builder.create<math::FmaOp>(p, q, c7);
1183+
Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
1184+
p = builder.create<math::FmaOp>(p, q, c8);
1185+
Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
1186+
p = builder.create<math::FmaOp>(p, q, c9);
1187+
1188+
Value d = builder.create<math::FmaOp>(pos2, a, one);
1189+
r = builder.create<arith::DivFOp>(one, d);
1190+
q = builder.create<math::FmaOp>(p, r, r);
1191+
Value negfa = builder.create<arith::NegFOp>(a);
1192+
Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
1193+
Value psubq = builder.create<arith::SubFOp>(p, q);
1194+
e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
1195+
r = builder.create<math::FmaOp>(e, r, q);
1196+
1197+
Value s = builder.create<arith::MulFOp>(a, a);
1198+
e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));
1199+
1200+
t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
1201+
r = builder.create<math::FmaOp>(
1202+
r, e,
1203+
builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));
1204+
1205+
Value isNotLessThanInf = builder.create<arith::XOrIOp>(
1206+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
1207+
trueValue);
1208+
r = builder.create<arith::SelectOp>(isNotLessThanInf,
1209+
builder.create<arith::AddFOp>(x, x), r);
1210+
Value isGreaterThanClamp =
1211+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
1212+
r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);
1213+
1214+
Value isNegative =
1215+
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
1216+
r = builder.create<arith::SelectOp>(
1217+
isNegative, builder.create<arith::SubFOp>(pos2, r), r);
1218+
1219+
rewriter.replaceOp(op, r);
1220+
return success();
1221+
}
11211222
//----------------------------------------------------------------------------//
11221223
// Exp approximation.
11231224
//----------------------------------------------------------------------------//
@@ -1667,6 +1768,11 @@ void mlir::populatePolynomialApproximateErfPattern(
16671768
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
16681769
}
16691770

1771+
void mlir::populatePolynomialApproximateErfcPattern(
1772+
RewritePatternSet &patterns) {
1773+
patterns.add<ErfcPolynomialApproximation>(patterns.getContext());
1774+
}
1775+
16701776
template <typename OpType>
16711777
static void
16721778
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
@@ -1690,6 +1796,7 @@ void mlir::populateMathF32ExpansionPatterns(
16901796
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
16911797
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
16921798
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
1799+
populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
16931800
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
16941801
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
16951802
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
@@ -1734,6 +1841,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
17341841
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
17351842
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
17361843
patterns, predicate);
1844+
populateMathPolynomialApproximationPattern<ErfcOp,
1845+
ErfcPolynomialApproximation>(
1846+
patterns, predicate);
17371847
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
17381848
patterns, predicate);
17391849
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
@@ -1760,9 +1870,10 @@ void mlir::populateMathPolynomialApproximationPatterns(
17601870
{math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
17611871
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
17621872
math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
1763-
math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
1764-
math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
1765-
math::SinOp::getOperationName(), math::CosOp::getOperationName()},
1873+
math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
1874+
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1875+
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
1876+
math::CosOp::getOperationName()},
17661877
name);
17671878
});
17681879

@@ -1774,8 +1885,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
17741885
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
17751886
math::Log2Op::getOperationName(),
17761887
math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
1777-
math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
1778-
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
1888+
math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
1889+
math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
1890+
math::ExpM1Op::getOperationName(),
17791891
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
17801892
math::CosOp::getOperationName()},
17811893
name);

0 commit comments

Comments
 (0)