@@ -43,6 +43,7 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
4343 PatternRewriter &rewriter) const {
4444 Location loc = op.getLoc ();
4545 Value x = op.getLhs ();
46+ auto fmf = op.getFastmathAttr ().getValue ();
4647
4748 FloatAttr scalarExponent;
4849 DenseFPElementsAttr vectorExponent;
@@ -78,14 +79,14 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
7879
7980 // Replace `pow(x, 2.0)` with `x * x`.
8081 if (isExponentValue (2.0 )) {
81- rewriter.replaceOpWithNewOp <arith::MulFOp>(op, ValueRange ({ x, x}) );
82+ rewriter.replaceOpWithNewOp <arith::MulFOp>(op, x, x, fmf );
8283 return success ();
8384 }
8485
8586 // Replace `pow(x, 3.0)` with `x * x * x`.
8687 if (isExponentValue (3.0 )) {
87- Value square = arith::MulFOp::create (rewriter, loc, ValueRange ({ x, x}) );
88- rewriter.replaceOpWithNewOp <arith::MulFOp>(op, ValueRange ({ x, square}) );
88+ Value square = arith::MulFOp::create (rewriter, loc, x, x, fmf );
89+ rewriter.replaceOpWithNewOp <arith::MulFOp>(op, x, square, fmf );
8990 return success ();
9091 }
9192
@@ -94,28 +95,27 @@ PowFStrengthReduction::matchAndRewrite(math::PowFOp op,
9495 Value one = arith::ConstantOp::create (
9596 rewriter, loc,
9697 rewriter.getFloatAttr (getElementTypeOrSelf (op.getType ()), 1.0 ));
97- rewriter.replaceOpWithNewOp <arith::DivFOp>(op, ValueRange ({ bcast (one), x}) );
98+ rewriter.replaceOpWithNewOp <arith::DivFOp>(op, bcast (one), x, fmf );
9899 return success ();
99100 }
100101
101102 // Replace `pow(x, 0.5)` with `sqrt(x)`.
102103 if (isExponentValue (0.5 )) {
103- rewriter.replaceOpWithNewOp <math::SqrtOp>(op, x);
104+ rewriter.replaceOpWithNewOp <math::SqrtOp>(op, x, fmf );
104105 return success ();
105106 }
106107
107108 // Replace `pow(x, -0.5)` with `rsqrt(x)`.
108109 if (isExponentValue (-0.5 )) {
109- rewriter.replaceOpWithNewOp <math::RsqrtOp>(op, x);
110+ rewriter.replaceOpWithNewOp <math::RsqrtOp>(op, x, fmf );
110111 return success ();
111112 }
112113
113114 // Replace `pow(x, 0.75)` with `sqrt(sqrt(x)) * sqrt(x)`.
114115 if (isExponentValue (0.75 )) {
115- Value powHalf = math::SqrtOp::create (rewriter, loc, x);
116- Value powQuarter = math::SqrtOp::create (rewriter, loc, powHalf);
117- rewriter.replaceOpWithNewOp <arith::MulFOp>(op,
118- ValueRange{powHalf, powQuarter});
116+ Value powHalf = math::SqrtOp::create (rewriter, loc, x, fmf);
117+ Value powQuarter = math::SqrtOp::create (rewriter, loc, powHalf, fmf);
118+ rewriter.replaceOpWithNewOp <arith::MulFOp>(op, powHalf, powQuarter, fmf);
119119 return success ();
120120 }
121121
0 commit comments