@@ -331,7 +331,8 @@ LogicalResult insertCasts(Operation *op, PatternRewriter &rewriter) {
331
331
SmallVector<Value> operands;
332
332
for (auto operand : op->getOperands ())
333
333
operands.push_back (rewriter.create <arith::ExtFOp>(loc, newType, operand));
334
- auto result = rewriter.create <math::Atan2Op>(loc, newType, operands);
334
+ auto result =
335
+ rewriter.create <T>(loc, TypeRange{newType}, operands, op->getAttrs ());
335
336
rewriter.replaceOpWithNewOp <arith::TruncFOp>(op, origType, result);
336
337
return success ();
337
338
}
@@ -1381,13 +1382,24 @@ RsqrtApproximation::matchAndRewrite(math::RsqrtOp op,
1381
1382
// Adds the math-dialect approximation rewrite patterns to `patterns`.
// Every pattern is constructed with `patterns.getContext()`. The only field
// of `options` read in this function is `enableAvx2`, which gates the Rsqrt
// patterns at the end.
//
// NOTE(review): this text is a diff rendering. Lines prefixed with `+` / `-`
// are the added / removed sides of the change, unprefixed code lines are
// unchanged context, and the bare numbers are old/new line numbers.
void mlir::populateMathPolynomialApproximationPatterns (
1382
1383
RewritePatternSet &patterns,
1383
1384
const MathPolynomialApproximationOptions &options) {
1385
// ReuseF32Expansion is defined outside this view; presumably it widens the
// operands to f32 and truncates the result back to the original type (the
// insertCasts hunk shows ExtFOp / TruncFOp being created) -- TODO confirm.
+ // Patterns for leveraging existing f32 lowerings on other data types.
1386
+ patterns
1387
+ .add <ReuseF32Expansion<math::AtanOp>, ReuseF32Expansion<math::Atan2Op>,
1388
+ ReuseF32Expansion<math::TanhOp>, ReuseF32Expansion<math::LogOp>,
1389
+ ReuseF32Expansion<math::Log2Op>, ReuseF32Expansion<math::Log1pOp>,
1390
+ ReuseF32Expansion<math::ErfOp>, ReuseF32Expansion<math::ExpOp>,
1391
+ ReuseF32Expansion<math::ExpM1Op>, ReuseF32Expansion<math::CbrtOp>,
1392
+ ReuseF32Expansion<math::SinOp>, ReuseF32Expansion<math::CosOp>>(
1393
+ patterns.getContext ());
1394
+
1384
1395
// The approximation patterns themselves, registered unconditionally.
patterns.add <AtanApproximation, Atan2Approximation, TanhApproximation,
1385
1396
LogApproximation, Log2Approximation, Log1pApproximation,
1386
1397
ErfPolynomialApproximation, ExpApproximation, ExpM1Approximation,
1387
// Before the change, ReuseF32Expansion was registered here for math::Atan2Op
// only; the added side moves it into the dedicated group above.
- CbrtApproximation, ReuseF32Expansion<math::Atan2Op>,
1388
- SinAndCosApproximation<true , math::SinOp>,
1398
+ CbrtApproximation, SinAndCosApproximation<true , math::SinOp>,
1389
1399
SinAndCosApproximation<false , math::CosOp>>(
1390
1400
patterns.getContext ());
1391
// The Rsqrt patterns are added only when options.enableAvx2 is set. The added
// side also registers ReuseF32Expansion<math::RsqrtOp> under the same gate.
- if (options.enableAvx2 )
1392
- patterns.add <RsqrtApproximation>(patterns.getContext ());
1401
+ if (options.enableAvx2 ) {
1402
+ patterns.add <RsqrtApproximation, ReuseF32Expansion<math::RsqrtOp>>(
1403
+ patterns.getContext ());
1404
+ }
1393
1405
}
0 commit comments