@@ -650,6 +650,11 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
650
650
rewriter.getMultiDimIdentityMap (rank));
651
651
SmallVector<utils::IteratorType> iteratorTypes (rank,
652
652
utils::IteratorType::parallel);
653
+
654
+ // Check 'fast-math'. If present, propagate it.
655
+ auto fmfOpInterface =
656
+ llvm::dyn_cast<arith::ArithFastMathInterface>(op.getOperation ());
657
+
653
658
auto genericOp = linalg::GenericOp::create (
654
659
rewriter, loc, tensorType,
655
660
op->getOperands (), // inputs
@@ -658,12 +663,32 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
658
663
[&](OpBuilder &builder, Location loc, ValueRange args) {
659
664
Value res;
660
665
if (args.size () == 2 ) {
661
- res =
662
- builder.create <OpTy>(loc, args[1 ].getType (), ValueRange{args[0 ]})
663
- .getResult ();
666
+ if (fmfOpInterface) {
667
+ auto attr = fmfOpInterface.getFastMathFlagsAttr ();
668
+ auto fmf = rewriter.getNamedAttr (" fastmath" , attr);
669
+ res = builder
670
+ .create <OpTy>(loc, args[1 ].getType (), ValueRange{args[0 ]},
671
+ fmf)
672
+ .getResult ();
673
+ } else {
674
+ res = builder
675
+ .create <OpTy>(loc, args[1 ].getType (), ValueRange{args[0 ]})
676
+ .getResult ();
677
+ }
664
678
} else if (args.size () == 3 ) {
665
- res = builder.create <OpTy>(loc, args[2 ].getType (),
666
- ValueRange{args[0 ], args[1 ]});
679
+ if (fmfOpInterface) {
680
+ auto attr = fmfOpInterface.getFastMathFlagsAttr ();
681
+ auto fmf = rewriter.getNamedAttr (" fastmath" , attr);
682
+ res = builder
683
+ .create <OpTy>(loc, args[2 ].getType (),
684
+ ValueRange{args[0 ], args[1 ]}, fmf)
685
+ .getResult ();
686
+ } else {
687
+ res = builder
688
+ .create <OpTy>(loc, args[2 ].getType (),
689
+ ValueRange{args[0 ], args[1 ]})
690
+ .getResult ();
691
+ }
667
692
} else
668
693
llvm_unreachable (" did not expect ops other than nary and binary" );
669
694
linalg::YieldOp::create (builder, loc, res);
0 commit comments