Skip to content

Commit 5522547

Browse files
committed
[mlir][linalg] Support permutation when lowering to loop nests
Linalg ops are perfect loop nests. When materializing the concrete loop nest, the default order specified by the Linalg op's iterators may not be the best for further CodeGen: targets frequently need to plan the loop order to obtain better data access patterns, and different targets can have different preferences. There should therefore be a way to control the loop order. Reviewed By: nicolasvasilache. Differential Revision: https://reviews.llvm.org/D91795
1 parent df86f15 commit 5522547

File tree

4 files changed

+137
-50
lines changed

4 files changed

+137
-50
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def LinalgFoldUnitExtentDims : FunctionPass<"linalg-fold-unit-extent-dims"> {
2828
let options = [
2929
Option<"foldOneTripLoopsOnly", "fold-one-trip-loops-only", "bool",
3030
/*default=*/"false",
31-
"Only folds the one-trip loops from Linalg ops on tensors "
32-
"(for testing purposes only)">
31+
"Only folds the one-trip loops from Linalg ops on tensors "
32+
"(for testing purposes only)">
3333
];
3434
let dependentDialects = ["linalg::LinalgDialect"];
3535
}
@@ -52,12 +52,24 @@ def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
5252
let summary = "Lower the operations from the linalg dialect into affine "
5353
"loops";
5454
let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
55+
let options = [
56+
ListOption<"interchangeVector", "interchange-vector", "unsigned",
57+
"Permute the loops in the nest following the given "
58+
"interchange vector",
59+
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
60+
];
5561
let dependentDialects = ["linalg::LinalgDialect", "AffineDialect"];
5662
}
5763

5864
def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
5965
let summary = "Lower the operations from the linalg dialect into loops";
6066
let constructor = "mlir::createConvertLinalgToLoopsPass()";
67+
let options = [
68+
ListOption<"interchangeVector", "interchange-vector", "unsigned",
69+
"Permute the loops in the nest following the given "
70+
"interchange vector",
71+
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
72+
];
6173
let dependentDialects = ["linalg::LinalgDialect", "scf::SCFDialect", "AffineDialect"];
6274
}
6375

@@ -72,6 +84,12 @@ def LinalgLowerToParallelLoops
7284
let summary = "Lower the operations from the linalg dialect into parallel "
7385
"loops";
7486
let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
87+
let options = [
88+
ListOption<"interchangeVector", "interchange-vector", "unsigned",
89+
"Permute the loops in the nest following the given "
90+
"interchange vector",
91+
"llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
92+
];
7593
let dependentDialects = ["AffineDialect", "linalg::LinalgDialect", "scf::SCFDialect"];
7694
}
7795

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -267,16 +267,28 @@ void vectorizeLinalgOp(OpBuilder &builder, Operation *op);
267267

268268
/// Emits a loop nest of `LoopTy` with the proper body for `op`.
269269
template <typename LoopTy>
270-
Optional<LinalgLoops> linalgLowerOpToLoops(OpBuilder &builder, Operation *op);
271-
272-
/// Emits a loop nest of `scf.for` with the proper body for `op`.
273-
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op);
274-
275-
/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
276-
LogicalResult linalgOpToParallelLoops(OpBuilder &builder, Operation *op);
270+
Optional<LinalgLoops>
271+
linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
272+
ArrayRef<unsigned> interchangeVector = {});
273+
274+
/// Emits a loop nest of `scf.for` with the proper body for `op`. The generated
275+
/// loop nest will follow the `interchangeVector`-permuted iterator order. If
276+
/// `interchangeVector` is empty, then no permutation happens.
277+
LogicalResult linalgOpToLoops(OpBuilder &builder, Operation *op,
278+
ArrayRef<unsigned> interchangeVector = {});
279+
280+
/// Emits a loop nest of `scf.parallel` with the proper body for `op`. The
281+
/// generated loop nest will follow the `interchangeVector`-permuted
282+
/// iterator order. If `interchangeVector` is empty, no permutation happens.
283+
LogicalResult
284+
linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
285+
ArrayRef<unsigned> interchangeVector = {});
277286

278-
/// Emits a loop nest of `affine.for` with the proper body for `op`.
279-
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op);
287+
/// Emits a loop nest of `affine.for` with the proper body for `op`. The
288+
/// generated loop nest will follow the `interchangeVector`-permuted
289+
/// iterator order. If `interchangeVector` is empty, no permutation happens.
290+
LogicalResult linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
291+
ArrayRef<unsigned> interchangeVector = {});
280292

281293
//===----------------------------------------------------------------------===//
282294
// Preconditions that ensure the corresponding transformation succeeds and can
@@ -587,13 +599,17 @@ enum class LinalgLoweringType {
587599
AffineLoops = 2,
588600
ParallelLoops = 3
589601
};
602+
590603
template <typename OpTy>
591604
struct LinalgLoweringPattern : public RewritePattern {
592605
LinalgLoweringPattern(MLIRContext *context, LinalgLoweringType loweringType,
593606
LinalgMarker marker = LinalgMarker(),
607+
ArrayRef<unsigned> interchangeVector = {},
594608
PatternBenefit benefit = 1)
595609
: RewritePattern(OpTy::getOperationName(), {}, benefit, context),
596-
marker(marker), loweringType(loweringType) {}
610+
marker(marker), loweringType(loweringType),
611+
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
612+
597613
// TODO: Move implementation to .cpp once named ops are auto-generated.
598614
LogicalResult matchAndRewrite(Operation *op,
599615
PatternRewriter &rewriter) const override {
@@ -603,18 +619,24 @@ struct LinalgLoweringPattern : public RewritePattern {
603619
if (failed(marker.checkAndNotify(rewriter, linalgOp)))
604620
return failure();
605621

606-
if (loweringType == LinalgLoweringType::LibraryCall) {
622+
switch (loweringType) {
623+
case LinalgLoweringType::LibraryCall:
607624
// TODO: Move lowering to library calls here.
608625
return failure();
609-
} else if (loweringType == LinalgLoweringType::Loops) {
610-
if (failed(linalgOpToLoops(rewriter, op)))
626+
case LinalgLoweringType::Loops:
627+
if (failed(linalgOpToLoops(rewriter, op, interchangeVector)))
611628
return failure();
612-
} else if (loweringType == LinalgLoweringType::AffineLoops) {
613-
if (failed(linalgOpToAffineLoops(rewriter, op)))
629+
break;
630+
case LinalgLoweringType::AffineLoops:
631+
if (failed(linalgOpToAffineLoops(rewriter, op, interchangeVector)))
614632
return failure();
615-
} else if (failed(linalgOpToParallelLoops(rewriter, op))) {
616-
return failure();
633+
break;
634+
case LinalgLoweringType::ParallelLoops:
635+
if (failed(linalgOpToParallelLoops(rewriter, op, interchangeVector)))
636+
return failure();
637+
break;
617638
}
639+
618640
rewriter.eraseOp(op);
619641
return success();
620642
}
@@ -625,6 +647,8 @@ struct LinalgLoweringPattern : public RewritePattern {
625647
/// Controls whether the pattern lowers to library calls, scf.for, affine.for
626648
/// or scf.parallel.
627649
LinalgLoweringType loweringType;
650+
/// Permuted loop order in the generated loop nest.
651+
SmallVector<unsigned, 4> interchangeVector;
628652
};
629653

630654
/// Linalg generalization patterns

mlir/lib/Dialect/Linalg/Transforms/Loops.cpp

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
#include "mlir/Transforms/DialectConversion.h"
2424
#include "mlir/Transforms/FoldUtils.h"
2525
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
26-
2726
#include "llvm/ADT/TypeSwitch.h"
2827

2928
using namespace mlir;
@@ -505,21 +504,31 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
505504
}
506505

507506
template <typename LoopTy>
508-
static Optional<LinalgLoops> linalgOpToLoopsImpl(Operation *op,
509-
OpBuilder &builder) {
507+
static Optional<LinalgLoops>
508+
linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
509+
ArrayRef<unsigned> interchangeVector) {
510510
using IndexedValueTy = typename GenerateLoopNest<LoopTy>::IndexedValueTy;
511-
512511
ScopedContext scope(builder, op->getLoc());
513512

514513
// The flattened loopToOperandRangesMaps is expected to be an invertible
515514
// permutation map (which is asserted in the inverse calculation).
516515
auto linalgOp = cast<LinalgOp>(op);
517516
assert(linalgOp.hasBufferSemantics() &&
518517
"expected linalg op with buffer semantics");
518+
519519
auto loopRanges = linalgOp.createLoopRanges(builder, op->getLoc());
520+
auto iteratorTypes = llvm::to_vector<4>(linalgOp.iterator_types().getValue());
521+
522+
if (!interchangeVector.empty()) {
523+
assert(interchangeVector.size() == loopRanges.size());
524+
assert(interchangeVector.size() == iteratorTypes.size());
525+
applyPermutationToVector(loopRanges, interchangeVector);
526+
applyPermutationToVector(iteratorTypes, interchangeVector);
527+
}
528+
520529
SmallVector<Value, 4> allIvs;
521530
GenerateLoopNest<LoopTy>::doit(
522-
loopRanges, /*iterInitArgs*/ {}, linalgOp.iterator_types().getValue(),
531+
loopRanges, /*iterInitArgs=*/{}, iteratorTypes,
523532
[&](ValueRange ivs, ValueRange iterArgs) -> scf::ValueVector {
524533
assert(iterArgs.empty() && "unexpected iterArgs");
525534
allIvs.append(ivs.begin(), ivs.end());
@@ -552,26 +561,33 @@ namespace {
552561
template <typename LoopType>
553562
class LinalgRewritePattern : public RewritePattern {
554563
public:
555-
LinalgRewritePattern() : RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()) {}
564+
LinalgRewritePattern(ArrayRef<unsigned> interchangeVector)
565+
: RewritePattern(/*benefit=*/1, MatchAnyOpTypeTag()),
566+
interchangeVector(interchangeVector.begin(), interchangeVector.end()) {}
556567

557568
LogicalResult matchAndRewrite(Operation *op,
558569
PatternRewriter &rewriter) const override {
559570
if (!isa<LinalgOp>(op))
560571
return failure();
561-
if (!linalgOpToLoopsImpl<LoopType>(op, rewriter))
572+
if (!linalgOpToLoopsImpl<LoopType>(op, rewriter, interchangeVector))
562573
return failure();
563574
rewriter.eraseOp(op);
564575
return success();
565576
}
577+
578+
private:
579+
SmallVector<unsigned, 4> interchangeVector;
566580
};
567581

568582
struct FoldAffineOp;
569583
} // namespace
570584

571585
template <typename LoopType>
572-
static void lowerLinalgToLoopsImpl(FuncOp funcOp, MLIRContext *context) {
586+
static void lowerLinalgToLoopsImpl(FuncOp funcOp,
587+
ArrayRef<unsigned> interchangeVector) {
588+
MLIRContext *context = funcOp.getContext();
573589
OwningRewritePatternList patterns;
574-
patterns.insert<LinalgRewritePattern<LoopType>>();
590+
patterns.insert<LinalgRewritePattern<LoopType>>(interchangeVector);
575591
DimOp::getCanonicalizationPatterns(patterns, context);
576592
AffineApplyOp::getCanonicalizationPatterns(patterns, context);
577593
patterns.insert<FoldAffineOp>(context);
@@ -620,20 +636,20 @@ struct FoldAffineOp : public RewritePattern {
620636
struct LowerToAffineLoops
621637
: public LinalgLowerToAffineLoopsBase<LowerToAffineLoops> {
622638
void runOnFunction() override {
623-
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), &getContext());
639+
lowerLinalgToLoopsImpl<AffineForOp>(getFunction(), interchangeVector);
624640
}
625641
};
626642

627643
struct LowerToLoops : public LinalgLowerToLoopsBase<LowerToLoops> {
628644
void runOnFunction() override {
629-
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), &getContext());
645+
lowerLinalgToLoopsImpl<scf::ForOp>(getFunction(), interchangeVector);
630646
}
631647
};
632648

633649
struct LowerToParallelLoops
634650
: public LinalgLowerToParallelLoopsBase<LowerToParallelLoops> {
635651
void runOnFunction() override {
636-
lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), &getContext());
652+
lowerLinalgToLoopsImpl<scf::ParallelOp>(getFunction(), interchangeVector);
637653
}
638654
};
639655
} // namespace
@@ -654,38 +670,43 @@ mlir::createConvertLinalgToAffineLoopsPass() {
654670

655671
/// Emits a loop nest with the proper body for `op`.
656672
template <typename LoopTy>
657-
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder,
658-
Operation *op) {
659-
return linalgOpToLoopsImpl<LoopTy>(op, builder);
673+
Optional<LinalgLoops>
674+
mlir::linalg::linalgLowerOpToLoops(OpBuilder &builder, Operation *op,
675+
ArrayRef<unsigned> interchangeVector) {
676+
return linalgOpToLoopsImpl<LoopTy>(op, builder, interchangeVector);
660677
}
661678

679+
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<AffineForOp>(
680+
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
681+
template Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(
682+
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
662683
template Optional<LinalgLoops>
663-
mlir::linalg::linalgLowerOpToLoops<AffineForOp>(OpBuilder &builder,
664-
Operation *op);
665-
template Optional<LinalgLoops>
666-
mlir::linalg::linalgLowerOpToLoops<scf::ForOp>(OpBuilder &builder,
667-
Operation *op);
668-
template Optional<LinalgLoops>
669-
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(OpBuilder &builder,
670-
Operation *op);
684+
mlir::linalg::linalgLowerOpToLoops<scf::ParallelOp>(
685+
OpBuilder &builder, Operation *op, ArrayRef<unsigned> interchangeVector);
671686

672687
/// Emits a loop nest of `affine.for` with the proper body for `op`.
673-
LogicalResult mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder,
674-
Operation *op) {
675-
Optional<LinalgLoops> loops = linalgLowerOpToLoops<AffineForOp>(builder, op);
688+
LogicalResult
689+
mlir::linalg::linalgOpToAffineLoops(OpBuilder &builder, Operation *op,
690+
ArrayRef<unsigned> interchangeVector) {
691+
Optional<LinalgLoops> loops =
692+
linalgLowerOpToLoops<AffineForOp>(builder, op, interchangeVector);
676693
return loops ? success() : failure();
677694
}
678695

679696
/// Emits a loop nest of `scf.for` with the proper body for `op`.
680-
LogicalResult mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op) {
681-
Optional<LinalgLoops> loops = linalgLowerOpToLoops<scf::ForOp>(builder, op);
697+
LogicalResult
698+
mlir::linalg::linalgOpToLoops(OpBuilder &builder, Operation *op,
699+
ArrayRef<unsigned> interchangeVector) {
700+
Optional<LinalgLoops> loops =
701+
linalgLowerOpToLoops<scf::ForOp>(builder, op, interchangeVector);
682702
return loops ? success() : failure();
683703
}
684704

685705
/// Emits a loop nest of `scf.parallel` with the proper body for `op`.
686-
LogicalResult mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder,
687-
Operation *op) {
706+
LogicalResult
707+
mlir::linalg::linalgOpToParallelLoops(OpBuilder &builder, Operation *op,
708+
ArrayRef<unsigned> interchangeVector) {
688709
Optional<LinalgLoops> loops =
689-
linalgLowerOpToLoops<scf::ParallelOp>(builder, op);
710+
linalgLowerOpToLoops<scf::ParallelOp>(builder, op, interchangeVector);
690711
return loops ? success() : failure();
691712
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt %s -convert-linalg-to-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=LOOP %s
2+
// RUN: mlir-opt %s -convert-linalg-to-parallel-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=PARALLEL %s
3+
// RUN: mlir-opt %s -convert-linalg-to-affine-loops="interchange-vector=4,0,3,1,2" | FileCheck --check-prefix=AFFINE %s
4+
5+
func @copy(%input: memref<1x2x3x4x5xf32>, %output: memref<1x2x3x4x5xf32>) {
6+
linalg.copy(%input, %output): memref<1x2x3x4x5xf32>, memref<1x2x3x4x5xf32>
7+
return
8+
}
9+
10+
// LOOP: scf.for %{{.*}} = %c0 to %c5 step %c1
11+
// LOOP: scf.for %{{.*}} = %c0 to %c1 step %c1
12+
// LOOP: scf.for %{{.*}} = %c0 to %c4 step %c1
13+
// LOOP: scf.for %{{.*}} = %c0 to %c2 step %c1
14+
// LOOP: scf.for %{{.*}} = %c0 to %c3 step %c1
15+
16+
// PARALLEL: scf.parallel
17+
// PARALLEL-SAME: to (%c5, %c1, %c4, %c2, %c3)
18+
19+
// AFFINE: affine.for %{{.*}} = 0 to 5
20+
// AFFINE: affine.for %{{.*}} = 0 to 1
21+
// AFFINE: affine.for %{{.*}} = 0 to 4
22+
// AFFINE: affine.for %{{.*}} = 0 to 2
23+
// AFFINE: affine.for %{{.*}} = 0 to 3
24+

0 commit comments

Comments
 (0)