|
19 | 19 | using namespace mlir; |
20 | 20 |
|
21 | 21 | namespace { |
22 | | -/// Move an scf.if op that is directly before the scf.condition op in the while |
23 | | -/// before region, and whose condition matches the condition of the |
24 | | -/// scf.condition op, down into the while after region. |
25 | | -/// |
26 | | -/// scf.while (%init) : (...) -> ... { |
27 | | -/// %cond = ... |
28 | | -/// %res = scf.if %cond -> (...) { |
29 | | -/// use1(%init) |
30 | | -/// %then_val = ... |
31 | | -/// ... // then block |
32 | | -/// scf.yield %then_val |
33 | | -/// } else { |
34 | | -/// scf.yield %init |
35 | | -/// } |
36 | | -/// scf.condition(%cond) %res |
37 | | -/// } do { |
38 | | -/// ^bb0(%arg): |
39 | | -/// use2(%arg) |
40 | | -/// ... |
41 | | -/// |
42 | | -/// becomes |
43 | | -/// scf.while (%init) : (...) -> ... { |
44 | | -/// %cond = ... |
45 | | -/// scf.condition(%cond) %init |
46 | | -/// } do { |
47 | | -/// ^bb0(%arg): : |
48 | | -/// use1(%arg) |
49 | | -/// ... // if then block |
50 | | -/// %then_val = ... |
51 | | -/// use2(%then_val) |
52 | | -/// ... |
53 | | -struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { |
54 | | - using OpRewritePattern<scf::WhileOp>::OpRewritePattern; |
55 | | - |
56 | | - LogicalResult matchAndRewrite(scf::WhileOp op, |
57 | | - PatternRewriter &rewriter) const override { |
58 | | - // Check that the first opeation produces one result and that result must |
59 | | - // have exactly two uses (these two uses come from the `scf.if` and |
60 | | - // `scf.condition` operations). |
61 | | - Operation &condOp = op.getBeforeBody()->front(); |
62 | | - if (condOp.getNumResults() != 1 || !condOp.getResult(0).hasNUses(2)) |
63 | | - return failure(); |
64 | | - |
65 | | - Value condVal = condOp.getResult(0); |
66 | | - auto ifOp = dyn_cast<scf::IfOp>(condOp.getNextNode()); |
67 | | - if (!ifOp || ifOp.getCondition() != condVal) |
68 | | - return failure(); |
69 | | - |
70 | | - auto term = dyn_cast<scf::ConditionOp>(ifOp->getNextNode()); |
71 | | - if (!term || term.getCondition() != condVal) |
72 | | - return failure(); |
73 | | - |
74 | | - // Check that if results and else yield operands match the scf.condition op |
75 | | - // arguments and while before arguments respectively. |
76 | | - if (!llvm::equal(ifOp->getResults(), term.getArgs()) || |
77 | | - !llvm::equal(ifOp.elseYield()->getOperands(), op.getBeforeArguments())) |
78 | | - return failure(); |
79 | | - |
80 | | - // Update uses and move the if op into the after region. |
81 | | - rewriter.replaceAllUsesWith(op.getAfterArguments(), |
82 | | - ifOp.thenYield()->getOperands()); |
83 | | - rewriter.replaceUsesWithIf(op.getBeforeArguments(), op.getAfterArguments(), |
84 | | - [&](OpOperand &use) { |
85 | | - return ifOp.getThenRegion().isAncestor( |
86 | | - use.getOwner()->getParentRegion()); |
87 | | - }); |
88 | | - rewriter.modifyOpInPlace( |
89 | | - term, [&]() { term.getArgsMutable().assign(op.getBeforeArguments()); }); |
90 | | - |
91 | | - rewriter.eraseOp(ifOp.thenYield()); |
92 | | - rewriter.inlineBlockBefore(ifOp.thenBlock(), op.getAfterBody(), |
93 | | - op.getAfterBody()->begin()); |
94 | | - rewriter.eraseOp(ifOp); |
95 | | - return success(); |
96 | | - } |
97 | | -}; |
98 | | - |
99 | 22 | struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { |
100 | 23 | using OpRewritePattern::OpRewritePattern; |
101 | 24 |
|
@@ -344,5 +267,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, |
344 | 267 | } |
345 | 268 |
|
346 | 269 | void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { |
347 | | - patterns.add<UpliftWhileOp, WhileMoveIfDown>(patterns.getContext()); |
| 270 | + patterns.add<UpliftWhileOp>(patterns.getContext()); |
348 | 271 | } |
0 commit comments