23
23
#include " mlir/Transforms/DialectConversion.h"
24
24
#include " mlir/Transforms/FoldUtils.h"
25
25
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
26
-
27
26
#include " llvm/ADT/TypeSwitch.h"
28
27
29
28
using namespace mlir ;
@@ -505,21 +504,31 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
505
504
}
506
505
507
506
template <typename LoopTy>
508
- static Optional<LinalgLoops> linalgOpToLoopsImpl (Operation *op,
509
- OpBuilder &builder) {
507
+ static Optional<LinalgLoops>
508
+ linalgOpToLoopsImpl (Operation *op, OpBuilder &builder,
509
+ ArrayRef<unsigned > interchangeVector) {
510
510
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
511
-
512
511
ScopedContext scope (builder, op->getLoc ());
513
512
514
513
// The flattened loopToOperandRangesMaps is expected to be an invertible
515
514
// permutation map (which is asserted in the inverse calculation).
516
515
auto linalgOp = cast<LinalgOp>(op);
517
516
assert (linalgOp.hasBufferSemantics () &&
518
517
" expected linalg op with buffer semantics" );
518
+
519
519
auto loopRanges = linalgOp.createLoopRanges (builder, op->getLoc ());
520
+ auto iteratorTypes = llvm::to_vector<4 >(linalgOp.iterator_types ().getValue ());
521
+
522
+ if (!interchangeVector.empty ()) {
523
+ assert (interchangeVector.size () == loopRanges.size ());
524
+ assert (interchangeVector.size () == iteratorTypes.size ());
525
+ applyPermutationToVector (loopRanges, interchangeVector);
526
+ applyPermutationToVector (iteratorTypes, interchangeVector);
527
+ }
528
+
520
529
SmallVector<Value, 4 > allIvs;
521
530
GenerateLoopNest<LoopTy>::doit (
522
- loopRanges, /* iterInitArgs*/ {}, linalgOp. iterator_types (). getValue () ,
531
+ loopRanges, /* iterInitArgs= */ {}, iteratorTypes ,
523
532
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
524
533
assert (iterArgs.empty () && " unexpected iterArgs" );
525
534
allIvs.append (ivs.begin (), ivs.end ());
@@ -552,26 +561,33 @@ namespace {
552
561
template <typename LoopType>
553
562
class LinalgRewritePattern : public RewritePattern {
554
563
public:
555
- LinalgRewritePattern () : RewritePattern(/* benefit=*/ 1 , MatchAnyOpTypeTag()) {}
564
+ LinalgRewritePattern (ArrayRef<unsigned > interchangeVector)
565
+ : RewritePattern(/* benefit=*/ 1 , MatchAnyOpTypeTag()),
566
+ interchangeVector (interchangeVector.begin(), interchangeVector.end()) {}
556
567
557
568
LogicalResult matchAndRewrite (Operation *op,
558
569
PatternRewriter &rewriter) const override {
559
570
if (!isa<LinalgOp>(op))
560
571
return failure ();
561
- if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
572
+ if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector ))
562
573
return failure ();
563
574
rewriter.eraseOp (op);
564
575
return success ();
565
576
}
577
+
578
+ private:
579
+ SmallVector<unsigned , 4 > interchangeVector;
566
580
};
567
581
568
582
struct FoldAffineOp ;
569
583
} // namespace
570
584
571
585
template <typename LoopType>
572
- static void lowerLinalgToLoopsImpl (FuncOp funcOp, MLIRContext *context) {
586
+ static void lowerLinalgToLoopsImpl (FuncOp funcOp,
587
+ ArrayRef<unsigned > interchangeVector) {
588
+ MLIRContext *context = funcOp.getContext ();
573
589
OwningRewritePatternList patterns;
574
- patterns.insert <LinalgRewritePattern<LoopType>>();
590
+ patterns.insert <LinalgRewritePattern<LoopType>>(interchangeVector );
575
591
DimOp::getCanonicalizationPatterns (patterns, context);
576
592
AffineApplyOp::getCanonicalizationPatterns (patterns, context);
577
593
patterns.insert <FoldAffineOp>(context);
@@ -620,20 +636,20 @@ struct FoldAffineOp : public RewritePattern {
620
636
struct LowerToAffineLoops
621
637
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
622
638
void runOnFunction () override {
623
- lowerLinalgToLoopsImpl<AffineForOp>(getFunction (), & getContext () );
639
+ lowerLinalgToLoopsImpl<AffineForOp>(getFunction (), interchangeVector );
624
640
}
625
641
};
626
642
627
643
struct LowerToLoops : public LinalgLowerToLoopsBase <LowerToLoops> {
628
644
void runOnFunction () override {
629
- lowerLinalgToLoopsImpl<scf::ForOp>(getFunction (), & getContext () );
645
+ lowerLinalgToLoopsImpl<scf::ForOp>(getFunction (), interchangeVector );
630
646
}
631
647
};
632
648
633
649
struct LowerToParallelLoops
634
650
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
635
651
void runOnFunction () override {
636
- lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction (), & getContext () );
652
+ lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction (), interchangeVector );
637
653
}
638
654
};
639
655
} // namespace
@@ -654,38 +670,43 @@ mlir::createConvertLinalgToAffineLoopsPass() {
654
670
655
671
// / Emits a loop nest with the proper body for `op`.
656
672
template <typename LoopTy>
657
- Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops (OpBuilder &builder,
658
- Operation *op) {
659
- return linalgOpToLoopsImpl<LoopTy>(op, builder);
673
+ Optional<LinalgLoops>
674
+ mlir::linalg::linalgLowerOpToLoops (OpBuilder &builder, Operation *op,
675
+ ArrayRef<unsigned > interchangeVector) {
676
+ return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
660
677
}
661
678
679
+ template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
680
+ OpBuilder &builder, Operation *op, ArrayRef<unsigned > interchangeVector);
681
+ template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
682
+ OpBuilder &builder, Operation *op, ArrayRef<unsigned > interchangeVector);
662
683
template Optional<LinalgLoops>
663
- mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
664
- Operation *op);
665
- template Optional<LinalgLoops>
666
- mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
667
- Operation *op);
668
- template Optional<LinalgLoops>
669
- mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
670
- Operation *op);
684
+ mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
685
+ OpBuilder &builder, Operation *op, ArrayRef<unsigned > interchangeVector);
671
686
672
687
// / Emits a loop nest of `affine.for` with the proper body for `op`.
673
- LogicalResult mlir::linalg::linalgOpToAffineLoops (OpBuilder &builder,
674
- Operation *op) {
675
- Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
688
+ LogicalResult
689
+ mlir::linalg::linalgOpToAffineLoops (OpBuilder &builder, Operation *op,
690
+ ArrayRef<unsigned > interchangeVector) {
691
+ Optional<LinalgLoops> loops =
692
+ linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
676
693
return loops ? success () : failure ();
677
694
}
678
695
679
696
// / Emits a loop nest of `scf.for` with the proper body for `op`.
680
- LogicalResult mlir::linalg::linalgOpToLoops (OpBuilder &builder, Operation *op) {
681
- Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
697
+ LogicalResult
698
+ mlir::linalg::linalgOpToLoops (OpBuilder &builder, Operation *op,
699
+ ArrayRef<unsigned > interchangeVector) {
700
+ Optional<LinalgLoops> loops =
701
+ linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
682
702
return loops ? success () : failure ();
683
703
}
684
704
685
705
// / Emits a loop nest of `scf.parallel` with the proper body for `op`.
686
- LogicalResult mlir::linalg::linalgOpToParallelLoops (OpBuilder &builder,
687
- Operation *op) {
706
+ LogicalResult
707
+ mlir::linalg::linalgOpToParallelLoops (OpBuilder &builder, Operation *op,
708
+ ArrayRef<unsigned > interchangeVector) {
688
709
Optional<LinalgLoops> loops =
689
- linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
710
+ linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector );
690
711
return loops ? success () : failure ();
691
712
}
0 commit comments