Skip to content

Commit 95e0ae9

Browse files
gmalasan and aartbik authored
[MLIR][SparseTensor] Loop ordering strategy infrastructure (flag) (#154656)
As discussed before, this PR adds the basic infrastructure/boiler plate for loop ordering strategies to be implemented. If this looks ok, I wanted to also mention some of the heuristics that I would implement next, if they sound reasonable to you guys: - Parallel first : prioritize parallel loops over reduction loops - Dense outer : prioritize the most dense loops first - Sparse outer : the opposite, potentially useful in some cases? There is another that I am considering, stride/memory aware, which would prioritize loops with better stride patterns (like sequential or linear). Not sure how well this carries over to Sparse Tensor though. Are there any ideas/heuristics that I should definitely try to implement? As we discussed, I will try to incrementally add heuristics. Sorry for the delay on my end, and thank you so much for the feedback! --------- Co-authored-by: Aart Bik <[email protected]>
1 parent 8889377 commit 95e0ae9

File tree

6 files changed

+77
-21
lines changed

6 files changed

+77
-21
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,16 @@ enum class SparseEmitStrategy {
5555
kDebugInterface, // generate only place-holder for sparse iteration
5656
};
5757

58+
namespace sparse_tensor {
59+
60+
/// Defines a strategy for loop ordering during sparse code generation.
61+
enum class LoopOrderingStrategy : unsigned {
62+
kDefault, ///< Default strategy (eagerly selects last loop in topological
63+
///< sort).
64+
};
65+
66+
} // namespace sparse_tensor
67+
5868
#define GEN_PASS_DECL
5969
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h.inc"
6070

@@ -71,11 +81,16 @@ std::unique_ptr<Pass> createSparseAssembler(bool directOut);
7181
// The SparseReinterpretMap pass.
7282
//===----------------------------------------------------------------------===//
7383

74-
void populateSparseReinterpretMap(RewritePatternSet &patterns,
75-
ReinterpretMapScope scope);
84+
void populateSparseReinterpretMap(
85+
RewritePatternSet &patterns, ReinterpretMapScope scope,
86+
sparse_tensor::LoopOrderingStrategy strategy =
87+
sparse_tensor::LoopOrderingStrategy::kDefault);
7688

7789
std::unique_ptr<Pass> createSparseReinterpretMapPass();
7890
std::unique_ptr<Pass> createSparseReinterpretMapPass(ReinterpretMapScope scope);
91+
std::unique_ptr<Pass>
92+
createSparseReinterpretMapPass(ReinterpretMapScope scope,
93+
sparse_tensor::LoopOrderingStrategy strategy);
7994

8095
//===----------------------------------------------------------------------===//
8196
// The PreSparsificationRewriting pass.

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,11 @@ def SparseReinterpretMap : Pass<"sparse-reinterpret-map", "ModuleOp"> {
8181
clEnumValN(mlir::ReinterpretMapScope::kExceptGeneric,
8282
"except-generic",
8383
"Run on operations expect linalg.generic (e.g., foreach)"))}]>,
84+
Option<"loopOrderingStrategy", "loop-ordering-strategy", "mlir::sparse_tensor::LoopOrderingStrategy",
85+
"mlir::sparse_tensor::LoopOrderingStrategy::kDefault",
86+
"Set the loop ordering strategy for sparse code generation", [{llvm::cl::values(
87+
clEnumValN(mlir::sparse_tensor::LoopOrderingStrategy::kDefault, "default",
88+
"Default strategy (eagerly selects last loop in topological sort)"))}]>,
8489
];
8590
}
8691

mlir/lib/Dialect/SparseTensor/Transforms/SparseReinterpretMap.cpp

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ struct DemapInsRewriter : public OpRewritePattern<SourceOp> {
5959

6060
// Flattens an affine expression into a list of AffineDimExprs.
6161
struct AffineDimCollector : public AffineExprVisitor<AffineDimCollector> {
62-
explicit AffineDimCollector(unsigned dimNum) : dims(dimNum){};
62+
explicit AffineDimCollector(unsigned dimNum) : dims(dimNum) {};
6363
void visitDimExpr(AffineDimExpr expr) { dims.set(expr.getPosition()); }
6464
BitVector dims;
6565
};
6666

6767
// Flattens an affine expression into a list of AffineDimExprs.
6868
struct AffineExprAdmissibleVisitor
6969
: public AffineExprVisitor<AffineExprAdmissibleVisitor> {
70-
explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput){};
70+
explicit AffineExprAdmissibleVisitor(bool isOutput) : isOutput(isOutput) {};
7171

7272
// We only allow AffineDimExpr on output.
7373
void visitAddExpr(AffineBinaryOpExpr expr) {
@@ -407,7 +407,10 @@ struct GenericOpReinterpretMap
407407
};
408408

409409
struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
410-
using OpRewritePattern::OpRewritePattern;
410+
GenericOpScheduler(MLIRContext *context,
411+
sparse_tensor::LoopOrderingStrategy strategy)
412+
: OpRewritePattern<linalg::GenericOp>(context), strategy(strategy) {}
413+
411414
LogicalResult matchAndRewrite(linalg::GenericOp linalgOp,
412415
PatternRewriter &rewriter) const override {
413416
if (linalgOp.getNumDpsInits() != 1 || !linalgOp.hasPureTensorSemantics() ||
@@ -420,7 +423,8 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
420423
if (linalgOp->hasAttr(sorted))
421424
return failure();
422425

423-
auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp);
426+
// Pass strategy to IterationGraphSorter.
427+
auto scheduler = IterationGraphSorter::fromGenericOp(linalgOp, strategy);
424428
bool isAdmissible = false;
425429
AffineMap order;
426430
// A const list of all masks that we used for iteration graph
@@ -582,6 +586,9 @@ struct GenericOpScheduler : public OpRewritePattern<linalg::GenericOp> {
582586
// TODO: convert more than one?
583587
return failure();
584588
}
589+
590+
private:
591+
sparse_tensor::LoopOrderingStrategy strategy;
585592
};
586593

587594
//===----------------------------------------------------------------------===//
@@ -786,12 +793,13 @@ struct ForeachOpDemapper
786793

787794
} // namespace
788795

789-
void mlir::populateSparseReinterpretMap(RewritePatternSet &patterns,
790-
ReinterpretMapScope scope) {
796+
void mlir::populateSparseReinterpretMap(
797+
RewritePatternSet &patterns, ReinterpretMapScope scope,
798+
sparse_tensor::LoopOrderingStrategy strategy) {
791799
if (scope == ReinterpretMapScope::kAll ||
792800
scope == ReinterpretMapScope::kGenericOnly) {
793-
patterns.add<GenericOpReinterpretMap, GenericOpScheduler>(
794-
patterns.getContext());
801+
patterns.add<GenericOpReinterpretMap>(patterns.getContext());
802+
patterns.add<GenericOpScheduler>(patterns.getContext(), strategy);
795803
}
796804
if (scope == ReinterpretMapScope::kAll ||
797805
scope == ReinterpretMapScope::kExceptGeneric) {

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorPasses.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,13 @@ struct SparseReinterpretMap
6767
SparseReinterpretMap(const SparseReinterpretMap &pass) = default;
6868
SparseReinterpretMap(const SparseReinterpretMapOptions &options) {
6969
scope = options.scope;
70+
loopOrderingStrategy = options.loopOrderingStrategy;
7071
}
7172

7273
void runOnOperation() override {
7374
auto *ctx = &getContext();
7475
RewritePatternSet patterns(ctx);
75-
populateSparseReinterpretMap(patterns, scope);
76+
populateSparseReinterpretMap(patterns, scope, loopOrderingStrategy);
7677
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
7778
}
7879
};
@@ -438,6 +439,14 @@ mlir::createSparseReinterpretMapPass(ReinterpretMapScope scope) {
438439
return std::make_unique<SparseReinterpretMap>(options);
439440
}
440441

442+
std::unique_ptr<Pass> mlir::createSparseReinterpretMapPass(
443+
ReinterpretMapScope scope, sparse_tensor::LoopOrderingStrategy strategy) {
444+
SparseReinterpretMapOptions options;
445+
options.scope = scope;
446+
options.loopOrderingStrategy = strategy;
447+
return std::make_unique<SparseReinterpretMap>(options);
448+
}
449+
441450
std::unique_ptr<Pass> mlir::createPreSparsificationRewritePass() {
442451
return std::make_unique<PreSparsificationRewritePass>();
443452
}

mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,15 @@ AffineMap IterationGraphSorter::topoSort() {
100100
// We always prefer a parallel loop over a reduction loop because putting
101101
// a reduction loop early might make the loop sequence inadmissible.
102102
auto &it = !parIt.empty() ? parIt : redIt;
103-
auto src = it.back();
103+
104+
// Select loop based on strategy.
105+
unsigned src;
106+
switch (strategy) {
107+
case sparse_tensor::LoopOrderingStrategy::kDefault:
108+
src = it.back();
109+
break;
110+
}
111+
104112
loopOrder.push_back(src);
105113
it.pop_back();
106114
// Update in-degree, and push 0-degree node into worklist.
@@ -122,8 +130,8 @@ AffineMap IterationGraphSorter::topoSort() {
122130
return AffineMap();
123131
}
124132

125-
IterationGraphSorter
126-
IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
133+
IterationGraphSorter IterationGraphSorter::fromGenericOp(
134+
linalg::GenericOp genericOp, sparse_tensor::LoopOrderingStrategy strategy) {
127135
// Must be a demapped sparse kernel.
128136
assert(!hasAnyNonIdentityOperandsOrResults(genericOp) &&
129137
hasAnySparseOperandOrResult(genericOp) &&
@@ -140,14 +148,16 @@ IterationGraphSorter::fromGenericOp(linalg::GenericOp genericOp) {
140148
genericOp.getIteratorTypesArray();
141149

142150
return IterationGraphSorter(std::move(ins), std::move(loopMap), out, outMap,
143-
std::move(iterTypes));
151+
std::move(iterTypes), strategy);
144152
}
145153

146154
IterationGraphSorter::IterationGraphSorter(
147155
SmallVector<Value> &&ins, SmallVector<AffineMap> &&loop2InsLvl, Value out,
148-
AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes)
156+
AffineMap loop2OutLvl, SmallVector<utils::IteratorType> &&iterTypes,
157+
sparse_tensor::LoopOrderingStrategy strategy)
149158
: ins(std::move(ins)), loop2InsLvl(std::move(loop2InsLvl)), out(out),
150-
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)) {
159+
loop2OutLvl(loop2OutLvl), iterTypes(std::move(iterTypes)),
160+
strategy(strategy) {
151161
// One map per tensor.
152162
assert(loop2InsLvl.size() == ins.size());
153163
// All the affine maps have the same number of dimensions (loops).

mlir/lib/Dialect/SparseTensor/Transforms/Utils/IterationGraphSorter.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#ifndef MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
1414
#define MLIR_DIALECT_SPARSETENSOR_TRANSFORMS_UTILS_ITERATIONGRAPHSORTER_H_
1515

16+
#include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
1617
#include "mlir/IR/AffineMap.h"
1718

1819
namespace mlir {
@@ -41,9 +42,12 @@ enum class SortMask : unsigned {
4142

4243
class IterationGraphSorter {
4344
public:
44-
/// Factory method that construct an iteration graph sorter
45-
/// for the given linalg.generic operation.
46-
static IterationGraphSorter fromGenericOp(linalg::GenericOp genericOp);
45+
/// Factory method that constructs an iteration graph sorter
46+
/// for the given linalg.generic operation with a specific loop ordering
47+
/// strategy.
48+
static IterationGraphSorter
49+
fromGenericOp(linalg::GenericOp genericOp,
50+
sparse_tensor::LoopOrderingStrategy strategy);
4751

4852
/// Returns a permutation that represents the scheduled loop order.
4953
/// Note that the returned AffineMap could be null if the kernel
@@ -58,7 +62,9 @@ class IterationGraphSorter {
5862
IterationGraphSorter(SmallVector<Value> &&ins,
5963
SmallVector<AffineMap> &&loop2InsLvl, Value out,
6064
AffineMap loop2OutLvl,
61-
SmallVector<utils::IteratorType> &&iterTypes);
65+
SmallVector<utils::IteratorType> &&iterTypes,
66+
sparse_tensor::LoopOrderingStrategy strategy =
67+
sparse_tensor::LoopOrderingStrategy::kDefault);
6268

6369
// Adds all the constraints in the given loop to level map.
6470
void addConstraints(Value t, AffineMap loop2LvlMap);
@@ -84,6 +90,9 @@ class IterationGraphSorter {
8490

8591
// InDegree used for topo sort.
8692
std::vector<unsigned> inDegree;
93+
94+
// Loop ordering strategy.
95+
sparse_tensor::LoopOrderingStrategy strategy;
8796
};
8897

8998
} // namespace sparse_tensor

0 commit comments

Comments
 (0)