Skip to content

Commit fc15676

Browse files
wsmosesivanradanov
andauthored
Enable early inner serialization pass and more aggressive serial LICM (#252)
* Early inner serialization * Add loweraffine * Add serial licm improvement * Fix yield bug * Simplify * Fix bug and simplify * Add fallback serialization * fix typo * add ser pass * Consider speculation branch * Correct isSpec * Expand if sink * Fix bug where BarrierElim removes barrier required for interchange * Fix format Co-authored-by: Ivan Radanov Ivanov <[email protected]>
1 parent 8c6545d commit fc15676

File tree

11 files changed

+811
-131
lines changed

11 files changed

+811
-131
lines changed

include/polygeist/Passes/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ std::unique_ptr<Pass> createParallelLICMPass();
1313
std::unique_ptr<Pass> createMem2RegPass();
1414
std::unique_ptr<Pass> createLoopRestructurePass();
1515
std::unique_ptr<Pass> createInnerSerializationPass();
16+
std::unique_ptr<Pass> createSerializationPass();
1617
std::unique_ptr<Pass> replaceAffineCFGPass();
1718
std::unique_ptr<Pass> createOpenMPOptPass();
1819
std::unique_ptr<Pass> createCanonicalizeForPass();

include/polygeist/Passes/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ def InnerSerialization : Pass<"inner-serialize"> {
4141
let dependentDialects =
4242
["memref::MemRefDialect", "func::FuncDialect", "LLVM::LLVMDialect"];
4343
}
44+
def Serialization : Pass<"serialize"> {
45+
let summary = "remove scf.barrier";
46+
let constructor = "mlir::polygeist::createSerializationPass()";
47+
let dependentDialects =
48+
["memref::MemRefDialect", "func::FuncDialect", "LLVM::LLVMDialect"];
49+
}
4450

4551
def SCFBarrierRemovalContinuation : InterfacePass<"barrier-removal-continuation", "FunctionOpInterface"> {
4652
let summary = "Remove scf.barrier using continuations";

lib/polygeist/Passes/InnerSerialization.cpp

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,18 @@ namespace {
2121
struct InnerSerialization : public InnerSerializationBase<InnerSerialization> {
2222
void runOnOperation() override;
2323
};
24+
struct Serialization : public SerializationBase<Serialization> {
25+
void runOnOperation() override;
26+
};
2427
} // namespace
2528

2629
struct ParSerialize : public OpRewritePattern<scf::ParallelOp> {
2730
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
2831

2932
LogicalResult matchAndRewrite(scf::ParallelOp nextParallel,
3033
PatternRewriter &rewriter) const override {
31-
if (!(nextParallel->getParentOfType<scf::ParallelOp>()
32-
// || nextParallel->getParentOfType<AffineParallelOp>()
33-
))
34+
if (!(nextParallel->getParentOfType<scf::ParallelOp>() ||
35+
nextParallel->getParentOfType<AffineParallelOp>()))
3436
return failure();
3537

3638
SmallVector<Value> inds;
@@ -53,6 +55,31 @@ struct ParSerialize : public OpRewritePattern<scf::ParallelOp> {
5355
}
5456
};
5557

58+
struct Serialize : public OpRewritePattern<scf::ParallelOp> {
59+
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
60+
61+
LogicalResult matchAndRewrite(scf::ParallelOp nextParallel,
62+
PatternRewriter &rewriter) const override {
63+
SmallVector<Value> inds;
64+
scf::ForOp last = nullptr;
65+
for (auto tup :
66+
llvm::zip(nextParallel.getLowerBound(), nextParallel.getUpperBound(),
67+
nextParallel.getStep(), nextParallel.getInductionVars())) {
68+
last =
69+
rewriter.create<scf::ForOp>(nextParallel.getLoc(), std::get<0>(tup),
70+
std::get<1>(tup), std::get<2>(tup));
71+
inds.push_back(last.getInductionVar());
72+
rewriter.setInsertionPointToStart(last.getBody());
73+
}
74+
rewriter.eraseOp(last.getBody()->getTerminator());
75+
rewriter.mergeBlocks(&nextParallel.getRegion().front(), last.getBody(),
76+
inds);
77+
78+
rewriter.eraseOp(nextParallel);
79+
return success();
80+
}
81+
};
82+
5683
void InnerSerialization::runOnOperation() {
5784
mlir::RewritePatternSet rpl(getOperation()->getContext());
5885
rpl.add<ParSerialize>(getOperation()->getContext());
@@ -61,6 +88,17 @@ void InnerSerialization::runOnOperation() {
6188
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config);
6289
}
6390

91+
void Serialization::runOnOperation() {
92+
mlir::RewritePatternSet rpl(getOperation()->getContext());
93+
rpl.add<Serialize>(getOperation()->getContext());
94+
GreedyRewriteConfig config;
95+
config.maxIterations = 47;
96+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config);
97+
}
98+
6499
std::unique_ptr<Pass> mlir::polygeist::createInnerSerializationPass() {
65100
return std::make_unique<InnerSerialization>();
66101
}
102+
std::unique_ptr<Pass> mlir::polygeist::createSerializationPass() {
103+
return std::make_unique<Serialization>();
104+
}

lib/polygeist/Passes/OpenMPOpt.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ bool isReadNone(Operation *op) {
100100
return false;
101101
}
102102

103+
Value getBase(Value v);
104+
bool isStackAlloca(Value v);
105+
bool isCaptured(Value v, Operation *potentialUser = nullptr,
106+
bool *seenuse = nullptr);
107+
103108
bool mayReadFrom(Operation *op, Value val) {
104109
bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveSideEffects>();
105110
if (hasRecursiveEffects) {
@@ -128,14 +133,16 @@ bool mayReadFrom(Operation *op, Value val) {
128133
}
129134
return false;
130135
}
136+
if (isa<LLVM::CallOp, func::CallOp>(op)) {
137+
auto base = getBase(val);
138+
bool seenuse = false;
139+
if (isStackAlloca(base) && !isCaptured(base, op, &seenuse) && !seenuse) {
140+
return false;
141+
}
142+
}
131143
return true;
132144
}
133145

134-
Value getBase(Value v);
135-
bool isStackAlloca(Value v);
136-
bool isCaptured(Value v, Operation *potentialUser = nullptr,
137-
bool *seenuse = nullptr);
138-
139146
bool mayWriteTo(Operation *op, Value val, bool ignoreBarrier) {
140147
bool hasRecursiveEffects = op->hasTrait<OpTrait::HasRecursiveSideEffects>();
141148
if (hasRecursiveEffects) {

0 commit comments

Comments
 (0)