@@ -56,10 +56,34 @@ static Value getDomaincedValue(DominanceInfo &dominanceInfo, Value a, Value b) {
5656 }
5757}
5858
59+ static bool isOpContainBlock (Operation *op, Block *block) {
60+ Operation *parentOp = block->getParentOp ();
61+ while (parentOp && parentOp != op) {
62+ parentOp = parentOp->getParentOp ();
63+ }
64+ return parentOp == op ? true : false ;
65+ }
66+
5967// / Find the hoisting position for the pure op.
6068static Value getDestPos (Operation *op) {
6169 DominanceInfo dominanceInfo (op);
6270 SmallVector<Value> operands (op->getOperands ());
71+ if (op->getNumRegions ()) {
72+ op->walk ([&](Operation *operation) {
73+ for (auto operand : operation->getOperands ()) {
74+ Operation *defineOp = operand.getDefiningOp ();
75+ if (!defineOp) {
76+ BlockArgument argument = cast<BlockArgument>(operand);
77+ if (!isOpContainBlock (op, argument.getOwner ()))
78+ operands.push_back (operand);
79+ continue ;
80+ }
81+ if (!isOpContainBlock (op, defineOp->getBlock ())) {
82+ operands.push_back (operand);
83+ }
84+ }
85+ });
86+ }
6387 if (operands.empty ())
6488 return {};
6589 Value ret = operands[0 ];
@@ -71,13 +95,18 @@ static Value getDestPos(Operation *op) {
7195
7296// / Hoist single pure op.
7397static void hoistPureOp (RewriterBase &rewriter, Operation *op) {
98+ LDBG () << " hoistPureOp: " << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
7499 Value pos = getDestPos (op);
75100 if (!pos)
76101 return ;
77102
78103 if (Operation *defineOp = pos.getDefiningOp ()) {
104+ if (op == defineOp)
105+ return ;
106+
79107 LDBG () << " move " << OpWithFlags (op, OpPrintingFlags ().skipRegions ())
80- << " after " << OpWithFlags (op, OpPrintingFlags ().skipRegions ());
108+ << " after "
109+ << OpWithFlags (defineOp, OpPrintingFlags ().skipRegions ());
81110 rewriter.moveOpAfter (op, defineOp);
82111 return ;
83112 }
0 commit comments