@@ -21,16 +21,18 @@ namespace {
21
21
struct InnerSerialization : public InnerSerializationBase <InnerSerialization> {
22
22
void runOnOperation () override ;
23
23
};
24
+ struct Serialization : public SerializationBase <Serialization> {
25
+ void runOnOperation () override ;
26
+ };
24
27
} // namespace
25
28
26
29
struct ParSerialize : public OpRewritePattern <scf::ParallelOp> {
27
30
using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
28
31
29
32
LogicalResult matchAndRewrite (scf::ParallelOp nextParallel,
30
33
PatternRewriter &rewriter) const override {
31
- if (!(nextParallel->getParentOfType <scf::ParallelOp>()
32
- // || nextParallel->getParentOfType<AffineParallelOp>()
33
- ))
34
+ if (!(nextParallel->getParentOfType <scf::ParallelOp>() ||
35
+ nextParallel->getParentOfType <AffineParallelOp>()))
34
36
return failure ();
35
37
36
38
SmallVector<Value> inds;
@@ -53,6 +55,31 @@ struct ParSerialize : public OpRewritePattern<scf::ParallelOp> {
53
55
}
54
56
};
55
57
58
+ struct Serialize : public OpRewritePattern <scf::ParallelOp> {
59
+ using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
60
+
61
+ LogicalResult matchAndRewrite (scf::ParallelOp nextParallel,
62
+ PatternRewriter &rewriter) const override {
63
+ SmallVector<Value> inds;
64
+ scf::ForOp last = nullptr ;
65
+ for (auto tup :
66
+ llvm::zip (nextParallel.getLowerBound (), nextParallel.getUpperBound (),
67
+ nextParallel.getStep (), nextParallel.getInductionVars ())) {
68
+ last =
69
+ rewriter.create <scf::ForOp>(nextParallel.getLoc (), std::get<0 >(tup),
70
+ std::get<1 >(tup), std::get<2 >(tup));
71
+ inds.push_back (last.getInductionVar ());
72
+ rewriter.setInsertionPointToStart (last.getBody ());
73
+ }
74
+ rewriter.eraseOp (last.getBody ()->getTerminator ());
75
+ rewriter.mergeBlocks (&nextParallel.getRegion ().front (), last.getBody (),
76
+ inds);
77
+
78
+ rewriter.eraseOp (nextParallel);
79
+ return success ();
80
+ }
81
+ };
82
+
56
83
void InnerSerialization::runOnOperation () {
57
84
mlir::RewritePatternSet rpl (getOperation ()->getContext ());
58
85
rpl.add <ParSerialize>(getOperation ()->getContext ());
@@ -61,6 +88,17 @@ void InnerSerialization::runOnOperation() {
61
88
(void )applyPatternsAndFoldGreedily (getOperation (), std::move (rpl), config);
62
89
}
63
90
91
+ void Serialization::runOnOperation () {
92
+ mlir::RewritePatternSet rpl (getOperation ()->getContext ());
93
+ rpl.add <Serialize>(getOperation ()->getContext ());
94
+ GreedyRewriteConfig config;
95
+ config.maxIterations = 47 ;
96
+ (void )applyPatternsAndFoldGreedily (getOperation (), std::move (rpl), config);
97
+ }
98
+
64
99
std::unique_ptr<Pass> mlir::polygeist::createInnerSerializationPass () {
65
100
return std::make_unique<InnerSerialization>();
66
101
}
102
+ std::unique_ptr<Pass> mlir::polygeist::createSerializationPass () {
103
+ return std::make_unique<Serialization>();
104
+ }
0 commit comments