17
17
#include " mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
18
18
#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
19
19
#include " mlir/Dialect/Linalg/IR/Linalg.h"
20
+ #include " mlir/Dialect/Linalg/Passes.h"
20
21
#include " mlir/Dialect/Linalg/Transforms/Transforms.h"
21
22
#include " mlir/Dialect/Tensor/IR/Tensor.h"
22
23
#include " mlir/Dialect/Utils/StaticValueUtils.h"
23
24
#include " mlir/IR/Matchers.h"
24
25
#include " mlir/IR/PatternMatch.h"
26
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
25
27
#include " llvm/ADT/STLExtras.h"
26
28
29
+ namespace mlir {
30
+ #define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
31
+ #include " mlir/Dialect/Linalg/Passes.h.inc"
32
+ } // namespace mlir
33
+
34
+ #define DEBUG_TYPE " linalg-convert-to-dps"
35
+
27
36
using namespace mlir ;
28
37
using namespace mlir ::tensor;
29
38
@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
96
105
OpBuilder::InsertionGuard g (rewriter);
97
106
RankedTensorType resultType = padOp.getResultType ();
98
107
99
- // Examine the yielded value to decide if a linalg.generic is neede or a
108
+ // Examine the yielded value to decide if a linalg.generic is needed or a
100
109
// linalg.fill is sufficient.
101
110
Value yieldedValue =
102
111
cast<tensor::YieldOp>(padOp.getBody ()->getTerminator ()).getValue ();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
603
612
}
604
613
605
614
namespace {
615
+ template <typename OpTy>
616
+ FailureOr<Operation *>
617
+ rewriteArithInDestinationPassingStyle (RewriterBase &rewriter, OpTy op) {
618
+ // reject ops such as `arith.constant` and `arith.select`.
619
+ auto numOperands = op->getNumOperands ();
620
+ if (numOperands == 0 || numOperands > 2 )
621
+ return failure ();
622
+
623
+ // destination passing style rewrite is only for ops on tensor types.
624
+ Type resultType = op->getResult (0 ).getType ();
625
+ auto tensorType = dyn_cast<RankedTensorType>(resultType);
626
+ if (!tensorType)
627
+ return failure ();
628
+
629
+ auto loc = op.getLoc ();
630
+ OpBuilder::InsertionGuard g (rewriter);
631
+ auto dynSizes = reifyOrComputeDynamicSizes (rewriter, op->getOperand (0 ));
632
+
633
+ // Create tensor.empty.
634
+ Value empty = tensor::EmptyOp::create (rewriter, loc, resultType, dynSizes);
635
+
636
+ // Create linalg.generic
637
+ auto rank = tensorType.getRank ();
638
+ SmallVector<AffineMap> indexingMaps (numOperands + 1 ,
639
+ rewriter.getMultiDimIdentityMap (rank));
640
+ SmallVector<utils::IteratorType> iteratorTypes (rank,
641
+ utils::IteratorType::parallel);
642
+ auto genericOp = linalg::GenericOp::create (
643
+ rewriter, loc, tensorType,
644
+ op->getOperands (), // inputs
645
+ ValueRange{empty}, // outputs
646
+ indexingMaps, iteratorTypes,
647
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
648
+ Value res;
649
+ if (args.size () == 2 ) {
650
+ res =
651
+ builder.create <OpTy>(loc, args[1 ].getType (), ValueRange{args[0 ]})
652
+ .getResult ();
653
+ } else if (args.size () == 3 ) {
654
+ res = builder.create <OpTy>(loc, args[2 ].getType (),
655
+ ValueRange{args[0 ], args[1 ]});
656
+ } else
657
+ llvm_unreachable (" did not expect ops other than nary and binary" );
658
+ linalg::YieldOp::create (builder, loc, res);
659
+ });
660
+
661
+ rewriter.replaceAllUsesWith (op, genericOp.getResult (0 ));
662
+ rewriter.eraseOp (op);
663
+ return genericOp.getOperation ();
664
+ }
606
665
607
666
template <typename OpTy>
608
667
LogicalResult rewriteOpInDestinationPassingStyle (OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
612
671
613
672
} // namespace
614
673
674
+ #define STAMP_OUT_ARITH_DPS_FUNCS (OPTY ) \
675
+ FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle ( \
676
+ RewriterBase &rewriter, OPTY op) { \
677
+ return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \
678
+ }
679
+
680
+ STAMP_OUT_ARITH_DPS_FUNCS (arith::UIToFPOp)
681
+ STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
682
+ STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
683
+ STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
684
+
685
+ STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
686
+
687
+ STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
688
+ STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
689
+
615
690
void linalg::populateConvertToDestinationStylePatterns(
616
691
RewritePatternSet &patterns) {
617
692
patterns.add (rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
618
693
patterns.add (rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
619
694
patterns.add (rewriteOpInDestinationPassingStyle<tensor::PadOp>);
695
+
696
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
697
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
698
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
699
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
700
+
701
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::AddIOp>);
702
+
703
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::AddFOp>);
704
+ patterns.add (rewriteOpInDestinationPassingStyle<arith::DivFOp>);
620
705
}
706
+
707
+ namespace {
708
+ struct LinalgConvertToDPSPass
709
+ : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
710
+ using impl::LinalgConvertToDPSPassBase<
711
+ LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
712
+
713
+ void runOnOperation () override ;
714
+ };
715
+
716
+ void LinalgConvertToDPSPass::runOnOperation () {
717
+
718
+ RewritePatternSet patterns (&getContext ());
719
+ linalg::populateConvertToDestinationStylePatterns (patterns);
720
+ if (failed (applyPatternsGreedily (getOperation (), std::move (patterns))))
721
+ signalPassFailure ();
722
+ }
723
+ } // namespace
0 commit comments