Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions mlir/include/mlir/Dialect/Math/IR/MathOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,31 @@ def Math_ErfOp : Math_FloatUnaryOp<"erf"> {
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
// ErfcOp
//===----------------------------------------------------------------------===//

// Unary floating-point op `math.erfc`; declares a folder (hasFolder = 1).
def Math_ErfcOp : Math_FloatUnaryOp<"erfc"> {
  let summary = "complementary error function of the specified value";
  let description = [{
    The `erfc` operation computes the complementary error function, defined
    as 1-erf(x). This function is part of libm and is needed for accuracy,
    since simply calculating 1-erf(x) when erf(x) is close to 1 (i.e., for
    large positive x) loses most significant digits to cancellation.
    It takes one operand of floating point type (i.e., scalar,
    tensor or vector) and returns one result of the same type. It has no
    standard attributes.

    Example:

    ```mlir
    // Scalar complementary error function value.
    %a = math.erfc %b : f64
    ```
  }];
  let hasFolder = 1;
}


//===----------------------------------------------------------------------===//
// ExpOp
Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Approximation.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ struct ErfPolynomialApproximation : public OpRewritePattern<math::ErfOp> {
PatternRewriter &rewriter) const final;
};

/// Rewrite pattern anchored on `math::ErfcOp`. Only the declaration lives
/// here; `matchAndRewrite` is defined out of line.
struct ErfcPolynomialApproximation : public OpRewritePattern<math::ErfcOp> {
  // Reuse the base-class constructors unchanged.
  using OpRewritePattern<math::ErfcOp>::OpRewritePattern;

  /// Attempts to rewrite `op` through `rewriter`; the returned LogicalResult
  /// reports whether the pattern applied.
  LogicalResult matchAndRewrite(math::ErfcOp op,
                                PatternRewriter &rewriter) const final;
};

} // namespace math
} // namespace mlir

Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Math/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ struct MathPolynomialApproximationOptions {

void populatePolynomialApproximateTanhPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfPattern(RewritePatternSet &patterns);
void populatePolynomialApproximateErfcPattern(RewritePatternSet &patterns);

// Adds patterns to convert to f32 around math functions for which `predicate`
// returns true.
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Conversion/MathToLibm/MathToLibm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ void mlir::populateMathToLibmConversionPatterns(RewritePatternSet &patterns,
populatePatternsForOp<math::CosOp>(patterns, benefit, ctx, "cosf", "cos");
populatePatternsForOp<math::CoshOp>(patterns, benefit, ctx, "coshf", "cosh");
populatePatternsForOp<math::ErfOp>(patterns, benefit, ctx, "erff", "erf");
populatePatternsForOp<math::ErfcOp>(patterns, benefit, ctx, "erfcf", "erfc");
populatePatternsForOp<math::ExpOp>(patterns, benefit, ctx, "expf", "exp");
populatePatternsForOp<math::Exp2Op>(patterns, benefit, ctx, "exp2f", "exp2");
populatePatternsForOp<math::ExpM1Op>(patterns, benefit, ctx, "expm1f",
Expand Down
18 changes: 18 additions & 0 deletions mlir/lib/Dialect/Math/IR/MathOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,24 @@ OpFoldResult math::ErfOp::fold(FoldAdaptor adaptor) {
});
}

//===----------------------------------------------------------------------===//
// ErfcOp folder
//===----------------------------------------------------------------------===//

/// Folder for `math.erfc`. Delegates to `constFoldUnaryOpConditional` with a
/// callback that evaluates the host `erfc` / `erfcf` for IEEE double / single
/// operands and declines (empty optional) for every other float semantics.
OpFoldResult math::ErfcOp::fold(FoldAdaptor adaptor) {
  return constFoldUnaryOpConditional<FloatAttr>(
      adaptor.getOperands(),
      [](const APFloat &operand) -> std::optional<APFloat> {
        const auto semantics = APFloat::SemanticsToEnum(operand.getSemantics());
        if (semantics == APFloat::Semantics::S_IEEEdouble)
          return APFloat(erfc(operand.convertToDouble()));
        if (semantics == APFloat::Semantics::S_IEEEsingle)
          return APFloat(erfcf(operand.convertToFloat()));
        // Any other semantics: report "no fold".
        return std::nullopt;
      });
}

//===----------------------------------------------------------------------===//
// IPowIOp folder
//===----------------------------------------------------------------------===//
Expand Down
122 changes: 117 additions & 5 deletions mlir/lib/Dialect/Math/Transforms/PolynomialApproximation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ handleMultidimensionalVectors(ImplicitLocOpBuilder &builder,
// Helper functions to create constants.
//----------------------------------------------------------------------------//

// Creates an `arith::ConstantOp` holding the boolean attribute for `value`
// and returns it as a Value.
static Value boolCst(ImplicitLocOpBuilder &builder, bool value) {
  auto boolAttr = builder.getBoolAttr(value);
  return builder.create<arith::ConstantOp>(boolAttr);
}

static Value floatCst(ImplicitLocOpBuilder &builder, float value,
Type elementType) {
assert((elementType.isF16() || elementType.isF32()) &&
Expand Down Expand Up @@ -1118,6 +1122,103 @@ ErfPolynomialApproximation::matchAndRewrite(math::ErfOp op,
return success();
}

// Approximates erfc(x) with p((x - 2) / (x + 2)), where p is a degree-9
// polynomial. This approximation is based on the following stackoverflow post:
// https://stackoverflow.com/questions/35966695/vectorizable-implementation-of-complementary-error-function-erfcf
// The stackoverflow post is in turn based on:
// M. M. Shepherd and J. G. Laframboise, "Chebyshev Approximation of
// (1+2x)exp(x^2)erfc x in 0 <= x < INF", Mathematics of Computation, Vol. 36,
// No. 153, January 1981, pp. 249-253.
//
// Maximum error: 2.65 ulps
// Replaces `op` with a sequence of arith/math ops computing an erfc
// approximation. Only f32 element types are handled; any other element type
// reports a match failure and leaves `op` untouched. On success `op` is
// replaced by the final value `r` and success() is returned.
LogicalResult
ErfcPolynomialApproximation::matchAndRewrite(math::ErfcOp op,
PatternRewriter &rewriter) const {
Value x = op.getOperand();
Type et = getElementTypeOrSelf(x);

if (!et.isF32())
return rewriter.notifyMatchFailure(op, "only f32 type is supported.");
std::optional<VectorShape> shape = vectorShape(x);

ImplicitLocOpBuilder builder(op->getLoc(), rewriter);
// Passes every constant through the file's `broadcast` helper together with
// the shape derived from the operand.
auto bcast = [&](Value value) -> Value {
return broadcast(builder, value, shape);
};

Value trueValue = bcast(boolCst(builder, true));
Value zero = bcast(floatCst(builder, 0.0f, et));
Value one = bcast(floatCst(builder, 1.0f, et));
Value onehalf = bcast(floatCst(builder, 0.5f, et));
Value neg4 = bcast(floatCst(builder, -4.0f, et));
Value neg2 = bcast(floatCst(builder, -2.0f, et));
Value pos2 = bcast(floatCst(builder, 2.0f, et));
Value posInf = bcast(floatCst(builder, INFINITY, et));
Value clampVal = bcast(floatCst(builder, 10.0546875f, et));

// a = |x|. The core computation works on the magnitude; the sign of x is
// handled by the final select.
Value a = builder.create<math::AbsFOp>(x);
// q = 1 - 4 / (a + 2), which is algebraically (a - 2) / (a + 2).
Value p = builder.create<arith::AddFOp>(a, pos2);
Value r = builder.create<arith::DivFOp>(one, p);
Value q = builder.create<math::FmaOp>(neg4, r, one);
// Correction step for q: e = a - 2 * (q + 1) - a * q is the residual
// (a - 2) - q * (a + 2); it is scaled by r = 1 / (a + 2) and added to q.
Value t = builder.create<math::FmaOp>(builder.create<arith::AddFOp>(q, one),
neg2, a);
Value e = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), q, t);
q = builder.create<math::FmaOp>(r, e, q);

// Horner evaluation of the degree-9 polynomial p(q): the leading coefficient
// is followed by nine fma steps, one per remaining coefficient c1..c9.
p = bcast(floatCst(builder, -0x1.a4a000p-12f, et)); // -4.01139259e-4
Value c1 = bcast(floatCst(builder, -0x1.42a260p-10f, et)); // -1.23075210e-3
p = builder.create<math::FmaOp>(p, q, c1);
Value c2 = bcast(floatCst(builder, 0x1.585714p-10f, et)); // 1.31355342e-3
p = builder.create<math::FmaOp>(p, q, c2);
Value c3 = bcast(floatCst(builder, 0x1.1adcc4p-07f, et)); // 8.63227434e-3
p = builder.create<math::FmaOp>(p, q, c3);
Value c4 = bcast(floatCst(builder, -0x1.081b82p-07f, et)); // -8.05991981e-3
p = builder.create<math::FmaOp>(p, q, c4);
Value c5 = bcast(floatCst(builder, -0x1.bc0b6ap-05f, et)); // -5.42046614e-2
p = builder.create<math::FmaOp>(p, q, c5);
Value c6 = bcast(floatCst(builder, 0x1.4ffc46p-03f, et)); // 1.64055392e-1
p = builder.create<math::FmaOp>(p, q, c6);
Value c7 = bcast(floatCst(builder, -0x1.540840p-03f, et)); // -1.66031361e-1
p = builder.create<math::FmaOp>(p, q, c7);
Value c8 = bcast(floatCst(builder, -0x1.7bf616p-04f, et)); // -9.27639827e-2
p = builder.create<math::FmaOp>(p, q, c8);
Value c9 = bcast(floatCst(builder, 0x1.1ba03ap-02f, et)); // 2.76978403e-1
p = builder.create<math::FmaOp>(p, q, c9);

// Divide (p + 1) by (1 + 2 * a): q = p * r + r with r = 1 / (2 * a + 1).
// NOTE(review): presumably p + 1 approximates (1 + 2a) * exp(a^2) * erfc(a),
// per the reference cited in the header comment -- confirm.
Value d = builder.create<math::FmaOp>(pos2, a, one);
r = builder.create<arith::DivFOp>(one, d);
q = builder.create<math::FmaOp>(p, r, r);
// Refine the quotient: e = 2 * (0.5 - q * a) + (p - q) is the residual
// (p + 1) - q * (1 + 2 * a); the refined value is r = e * r + q.
Value negfa = builder.create<arith::NegFOp>(a);
Value fmaqah = builder.create<math::FmaOp>(q, negfa, onehalf);
Value psubq = builder.create<arith::SubFOp>(p, q);
e = builder.create<math::FmaOp>(fmaqah, pos2, psubq);
r = builder.create<math::FmaOp>(e, r, q);

// Multiply by exp(-a * a).
Value s = builder.create<arith::MulFOp>(a, a);
e = builder.create<math::ExpOp>(builder.create<arith::NegFOp>(s));

// t = fma(-a, a, s): if the fma is evaluated fused, this is the rounding
// error of the product s = a * a. The result is r * e + (r * e) * t.
t = builder.create<math::FmaOp>(builder.create<arith::NegFOp>(a), a, s);
r = builder.create<math::FmaOp>(
r, e,
builder.create<arith::MulFOp>(builder.create<arith::MulFOp>(r, e), t));

// !(a < +inf), built as (a OLT inf) xor true. OLT is an ordered comparison,
// so the negation holds when a is +inf or NaN; in that case r = x + x.
Value isNotLessThanInf = builder.create<arith::XOrIOp>(
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, a, posInf),
trueValue);
r = builder.create<arith::SelectOp>(isNotLessThanInf,
builder.create<arith::AddFOp>(x, x), r);
// For a > 10.0546875 the result is forced to zero. This select comes after
// the previous one, so it also overrides the a == +inf case; NaN inputs are
// unaffected because OGT is false for NaN.
Value isGreaterThanClamp =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OGT, a, clampVal);
r = builder.create<arith::SelectOp>(isGreaterThanClamp, zero, r);

// Negative inputs: select 2 - r (i.e. erfc(x) = 2 - erfc(|x|)) when x < 0.
Value isNegative =
builder.create<arith::CmpFOp>(arith::CmpFPredicate::OLT, x, zero);
r = builder.create<arith::SelectOp>(
isNegative, builder.create<arith::SubFOp>(pos2, r), r);

rewriter.replaceOp(op, r);
return success();
}
//----------------------------------------------------------------------------//
// Exp approximation.
//----------------------------------------------------------------------------//
Expand Down Expand Up @@ -1667,6 +1768,11 @@ void mlir::populatePolynomialApproximateErfPattern(
patterns.add<ErfPolynomialApproximation>(patterns.getContext());
}

// Adds the ErfcPolynomialApproximation rewrite pattern to `patterns`,
// constructing it with the set's own context.
void mlir::populatePolynomialApproximateErfcPattern(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  patterns.add<ErfcPolynomialApproximation>(context);
}

template <typename OpType>
static void
populateMathF32ExpansionPattern(RewritePatternSet &patterns,
Expand All @@ -1690,6 +1796,7 @@ void mlir::populateMathF32ExpansionPatterns(
populateMathF32ExpansionPattern<math::CosOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::CoshOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::ErfOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::ErfcOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::ExpOp>(patterns, predicate);
populateMathF32ExpansionPattern<math::Exp2Op>(patterns, predicate);
populateMathF32ExpansionPattern<math::ExpM1Op>(patterns, predicate);
Expand Down Expand Up @@ -1734,6 +1841,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
CosOp, SinAndCosApproximation<false, math::CosOp>>(patterns, predicate);
populateMathPolynomialApproximationPattern<ErfOp, ErfPolynomialApproximation>(
patterns, predicate);
populateMathPolynomialApproximationPattern<ErfcOp,
ErfcPolynomialApproximation>(
patterns, predicate);
populateMathPolynomialApproximationPattern<ExpOp, ExpApproximation>(
patterns, predicate);
populateMathPolynomialApproximationPattern<ExpM1Op, ExpM1Approximation>(
Expand All @@ -1760,9 +1870,10 @@ void mlir::populateMathPolynomialApproximationPatterns(
{math::AtanOp::getOperationName(), math::Atan2Op::getOperationName(),
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
math::Log2Op::getOperationName(), math::Log1pOp::getOperationName(),
math::ErfOp::getOperationName(), math::ExpOp::getOperationName(),
math::ExpM1Op::getOperationName(), math::CbrtOp::getOperationName(),
math::SinOp::getOperationName(), math::CosOp::getOperationName()},
math::ErfOp::getOperationName(), math::ErfcOp::getOperationName(),
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
math::CosOp::getOperationName()},
name);
});

Expand All @@ -1774,8 +1885,9 @@ void mlir::populateMathPolynomialApproximationPatterns(
math::TanhOp::getOperationName(), math::LogOp::getOperationName(),
math::Log2Op::getOperationName(),
math::Log1pOp::getOperationName(), math::ErfOp::getOperationName(),
math::AsinOp::getOperationName(), math::AcosOp::getOperationName(),
math::ExpOp::getOperationName(), math::ExpM1Op::getOperationName(),
math::ErfcOp::getOperationName(), math::AsinOp::getOperationName(),
math::AcosOp::getOperationName(), math::ExpOp::getOperationName(),
math::ExpM1Op::getOperationName(),
math::CbrtOp::getOperationName(), math::SinOp::getOperationName(),
math::CosOp::getOperationName()},
name);
Expand Down
Loading