Skip to content

Commit af0cbf4

Browse files
support region op.
1 parent 397448a commit af0cbf4

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

mlir/lib/Transforms/HoistPureOps.cpp

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,34 @@ static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
5656
}
5757
}
5858

59+
static bool isOpContainBlock(Operation *op, Block *block) {
60+
Operation *parentOp = block->getParentOp();
61+
while (parentOp && parentOp != op) {
62+
parentOp = parentOp->getParentOp();
63+
}
64+
return parentOp == op ? true : false;
65+
}
66+
5967
/// Find the hoisting position for the pure op.
6068
static Value getDestPos(Operation *op) {
6169
DominanceInfo dominanceInfo(op);
6270
SmallVector<Value> operands(op->getOperands());
71+
if (op->getNumRegions()) {
72+
op->walk([&](Operation *operation) {
73+
for (auto operand : operation->getOperands()) {
74+
Operation *defineOp = operand.getDefiningOp();
75+
if (!defineOp) {
76+
BlockArgument argument = cast<BlockArgument>(operand);
77+
if (!isOpContainBlock(op, argument.getOwner()))
78+
operands.push_back(operand);
79+
continue;
80+
}
81+
if (!isOpContainBlock(op, defineOp->getBlock())) {
82+
operands.push_back(operand);
83+
}
84+
}
85+
});
86+
}
6387
if (operands.empty())
6488
return {};
6589
Value ret = operands[0];
@@ -71,13 +95,18 @@ static Value getDestPos(Operation *op) {
7195

7296
/// Hoist single pure op.
7397
static void hoistPureOp(RewriterBase &rewriter, Operation *op) {
98+
LDBG() << "hoistPureOp: " << OpWithFlags(op, OpPrintingFlags().skipRegions());
7499
Value pos = getDestPos(op);
75100
if (!pos)
76101
return;
77102

78103
if (Operation *defineOp = pos.getDefiningOp()) {
104+
if (op == defineOp)
105+
return;
106+
79107
LDBG() << "move " << OpWithFlags(op, OpPrintingFlags().skipRegions())
80-
<< " after " << OpWithFlags(op, OpPrintingFlags().skipRegions());
108+
<< " after "
109+
<< OpWithFlags(defineOp, OpPrintingFlags().skipRegions());
81110
rewriter.moveOpAfter(op, defineOp);
82111
return;
83112
}

0 commit comments

Comments
 (0)