|
15 | 15 | #include "mlir/Dialect/SCF/Transforms/Patterns.h" |
16 | 16 | #include "mlir/IR/Dominance.h" |
17 | 17 | #include "mlir/IR/PatternMatch.h" |
18 | | -#include "mlir/Transforms/RegionUtils.h" |
19 | 18 |
|
20 | 19 | using namespace mlir; |
21 | 20 |
|
22 | 21 | namespace { |
23 | | -/// Move a scf.if op that is directly before the scf.condition op in the while |
24 | | -/// before region, and whose condition matches the condition of the |
25 | | -/// scf.condition op, down into the while after region. |
26 | | -/// |
27 | | -/// scf.while (..) : (...) -> ... { |
28 | | -/// %additional_used_values = ... |
29 | | -/// %cond = ... |
30 | | -/// ... |
31 | | -/// %res = scf.if %cond -> (...) { |
32 | | -/// use(%additional_used_values) |
33 | | -/// ... // then block |
34 | | -/// scf.yield %then_value |
35 | | -/// } else { |
36 | | -/// scf.yield %else_value |
37 | | -/// } |
38 | | -/// scf.condition(%cond) %res, ... |
39 | | -/// } do { |
40 | | -/// ^bb0(%res_arg, ...): |
41 | | -/// use(%res_arg) |
42 | | -/// ... |
43 | | -/// |
44 | | -/// becomes |
45 | | -/// scf.while (..) : (...) -> ... { |
46 | | -/// %additional_used_values = ... |
47 | | -/// %cond = ... |
48 | | -/// ... |
49 | | -/// scf.condition(%cond) %else_value, ..., %additional_used_values |
50 | | -/// } do { |
51 | | -/// ^bb0(%res_arg ..., %additional_args): : |
52 | | -/// use(%additional_args) |
53 | | -/// ... // if then block |
54 | | -/// use(%then_value) |
55 | | -/// ... |
56 | | -struct WhileMoveIfDown : public OpRewritePattern<scf::WhileOp> { |
57 | | - using OpRewritePattern<scf::WhileOp>::OpRewritePattern; |
58 | | - |
59 | | - LogicalResult matchAndRewrite(scf::WhileOp op, |
60 | | - PatternRewriter &rewriter) const override { |
61 | | - auto conditionOp = |
62 | | - cast<scf::ConditionOp>(op.getBeforeBody()->getTerminator()); |
63 | | - auto ifOp = dyn_cast_or_null<scf::IfOp>(conditionOp->getPrevNode()); |
64 | | - |
65 | | - // Check that the ifOp is directly before the conditionOp and that it |
66 | | - // matches the condition of the conditionOp. Also ensure that the ifOp has |
67 | | - // no else block with content, as that would complicate the transformation. |
68 | | - // TODO: support else blocks with content. |
69 | | - if (!ifOp || ifOp.getCondition() != conditionOp.getCondition() || |
70 | | - (ifOp.elseBlock() && !ifOp.elseBlock()->without_terminator().empty())) |
71 | | - return failure(); |
72 | | - |
73 | | - assert(ifOp->use_empty() || (llvm::all_equal(ifOp->getUsers()) && |
74 | | - *ifOp->user_begin() == conditionOp) && |
75 | | - "ifOp has unexpected uses"); |
76 | | - |
77 | | - Location loc = op.getLoc(); |
78 | | - |
79 | | - // Replace uses of ifOp results in the conditionOp with the yielded values |
80 | | - // from the ifOp branches. |
81 | | - for (auto [idx, arg] : llvm::enumerate(conditionOp.getArgs())) { |
82 | | - auto it = llvm::find(ifOp->getResults(), arg); |
83 | | - if (it != ifOp->getResults().end()) { |
84 | | - size_t ifOpIdx = it.getIndex(); |
85 | | - Value thenValue = ifOp.thenYield()->getOperand(ifOpIdx); |
86 | | - Value elseValue = ifOp.elseYield()->getOperand(ifOpIdx); |
87 | | - |
88 | | - rewriter.replaceAllUsesWith(ifOp->getResults()[ifOpIdx], elseValue); |
89 | | - rewriter.replaceAllUsesWith(op.getAfterArguments()[idx], thenValue); |
90 | | - } |
91 | | - } |
92 | | - |
93 | | - // Collect additional used values from before region. |
94 | | - SetVector<Value> additionalUsedValues; |
95 | | - visitUsedValuesDefinedAbove(ifOp.getThenRegion(), [&](OpOperand *operand) { |
96 | | - if (op.getBefore().isAncestor(operand->get().getParentRegion())) |
97 | | - additionalUsedValues.insert(operand->get()); |
98 | | - }); |
99 | | - |
100 | | - // Create new whileOp with additional used values as results. |
101 | | - auto additionalValueTypes = llvm::map_to_vector( |
102 | | - additionalUsedValues, [](Value val) { return val.getType(); }); |
103 | | - size_t additionalValueSize = additionalUsedValues.size(); |
104 | | - SmallVector<Type> newResultTypes(op.getResultTypes()); |
105 | | - newResultTypes.append(additionalValueTypes); |
106 | | - |
107 | | - auto newWhileOp = |
108 | | - scf::WhileOp::create(rewriter, loc, newResultTypes, op.getInits()); |
109 | | - |
110 | | - newWhileOp.getBefore().takeBody(op.getBefore()); |
111 | | - newWhileOp.getAfter().takeBody(op.getAfter()); |
112 | | - newWhileOp.getAfter().addArguments( |
113 | | - additionalValueTypes, SmallVector<Location>(additionalValueSize, loc)); |
114 | | - |
115 | | - SmallVector<Value> conditionArgs = conditionOp.getArgs(); |
116 | | - llvm::append_range(conditionArgs, additionalUsedValues); |
117 | | - |
118 | | - // Update conditionOp inside new whileOp before region. |
119 | | - rewriter.setInsertionPoint(conditionOp); |
120 | | - rewriter.replaceOpWithNewOp<scf::ConditionOp>( |
121 | | - conditionOp, conditionOp.getCondition(), conditionArgs); |
122 | | - |
123 | | - // Replace uses of additional used values inside the ifOp then region with |
124 | | - // the whileOp after region arguments. |
125 | | - rewriter.replaceUsesWithIf( |
126 | | - additionalUsedValues.takeVector(), |
127 | | - newWhileOp.getAfterArguments().take_back(additionalValueSize), |
128 | | - [&](OpOperand &use) { |
129 | | - return ifOp.getThenRegion().isAncestor( |
130 | | - use.getOwner()->getParentRegion()); |
131 | | - }); |
132 | | - |
133 | | - // Inline ifOp then region into new whileOp after region. |
134 | | - rewriter.eraseOp(ifOp.thenYield()); |
135 | | - rewriter.inlineBlockBefore(ifOp.thenBlock(), newWhileOp.getAfterBody(), |
136 | | - newWhileOp.getAfterBody()->begin()); |
137 | | - rewriter.eraseOp(ifOp); |
138 | | - rewriter.replaceOp(op, |
139 | | - newWhileOp->getResults().drop_back(additionalValueSize)); |
140 | | - return success(); |
141 | | - } |
142 | | -}; |
143 | | - |
144 | 22 | struct UpliftWhileOp : public OpRewritePattern<scf::WhileOp> { |
145 | 23 | using OpRewritePattern::OpRewritePattern; |
146 | 24 |
|
@@ -389,6 +267,5 @@ FailureOr<scf::ForOp> mlir::scf::upliftWhileToForLoop(RewriterBase &rewriter, |
389 | 267 | } |
390 | 268 |
|
391 | 269 | void mlir::scf::populateUpliftWhileToForPatterns(RewritePatternSet &patterns) { |
392 | | - patterns.add<WhileMoveIfDown, UpliftWhileOp>(patterns.getContext()); |
393 | | - scf::WhileOp::getCanonicalizationPatterns(patterns, patterns.getContext()); |
| 270 | + patterns.add<UpliftWhileOp>(patterns.getContext()); |
394 | 271 | } |
0 commit comments