Skip to content

Commit 4b463ce

Browse files
committed
[mlir][math] Propagate fast math attrs in AlgebraicSimplification
1 parent 5378706 commit 4b463ce

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

mlir/lib/Dialect/Math/Transforms/AlgebraicSimplification.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
4343
PatternRewriter &rewriter) const {
4444
Location loc = op.getLoc();
4545
Value x = op.getLhs();
46+
auto fmf = op.getFastmathAttr().getValue();
4647

4748
FloatAttr scalarExponent;
4849
DenseFPElementsAttr vectorExponent;
@@ -78,14 +79,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
7879

7980
// Replace `pow(x, 2.0)` with `x * x`.
8081
if (isExponentValue(2.0)) {
81-
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, x}));
82+
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, x, fmf);
8283
return success();
8384
}
8485

8586
// Replace `pow(x, 3.0)` with `x * x * x`.
8687
if (isExponentValue(3.0)) {
87-
Value square = arith::MulFOp::create(rewriter, loc, ValueRange({x, x}));
88-
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, ValueRange({x, square}));
88+
Value square = arith::MulFOp::create(rewriter, loc, x, x, fmf);
89+
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, x, square, fmf);
8990
return success();
9091
}
9192

@@ -94,28 +95,27 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
9495
Value one = arith::ConstantOp::create(
9596
rewriter, loc,
9697
rewriter.getFloatAttr(getElementTypeOrSelf(op.getType()), 1.0));
97-
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, ValueRange({bcast(one), x}));
98+
rewriter.replaceOpWithNewOp<arith::DivFOp>(op, bcast(one), x, fmf);
9899
return success();
99100
}
100101

101102
// Replace `pow(x, 0.5)` with `sqrt(x)`.
102103
if (isExponentValue(0.5)) {
103-
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x);
104+
rewriter.replaceOpWithNewOp<math::SqrtOp>(op, x, fmf);
104105
return success();
105106
}
106107

107108
// Replace `pow(x, -0.5)` with `rsqrt(x)`.
108109
if (isExponentValue(-0.5)) {
109-
rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x);
110+
rewriter.replaceOpWithNewOp<math::RsqrtOp>(op, x, fmf);
110111
return success();
111112
}
112113

113114
// Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
114115
if (isExponentValue(0.75)) {
115-
Value powHalf = math::SqrtOp::create(rewriter, loc, x);
116-
Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf);
117-
rewriter.replaceOpWithNewOp<arith::MulFOp>(op,
118-
ValueRange{powHalf, powQuarter});
116+
Value powHalf = math::SqrtOp::create(rewriter, loc, x, fmf);
117+
Value powQuarter = math::SqrtOp::create(rewriter, loc, powHalf, fmf);
118+
rewriter.replaceOpWithNewOp<arith::MulFOp>(op, powHalf, powQuarter, fmf);
119119
return success();
120120
}
121121

0 commit comments

Comments
 (0)