[MLIR] Removing dead values for branches #117501
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -165,6 +165,17 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { | |
| return opOperands; | ||
| } | ||
|
|
||
| // Returns true if any of the given users implements `BranchOpInterface`. | ||
| template <typename UserRange> | ||
| static bool anyBranchUsers(const UserRange &users) { | ||
| for (auto user : users) { | ||
| if (auto subBranchOp = dyn_cast<BranchOpInterface>(user)) { | ||
| return true; | ||
| } | ||
| } | ||
| return false; | ||
| } | ||
|
|
||
| /// Clean a simple op `op`, given the liveness analysis information in `la`. | ||
| /// Here, cleaning means: | ||
| /// (1) Dropping all its uses, AND | ||
|
|
@@ -175,7 +186,8 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) { | |
| /// symbol op, a symbol-user op, a region branch op, a branch op, a region | ||
| /// branch terminator op, or return-like. | ||
| static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) { | ||
| if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la)) | ||
| if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la) || | ||
| anyBranchUsers(op->getUsers())) | ||
| return; | ||
|
|
||
| op->dropAllUses(); | ||
|
|
@@ -563,6 +575,48 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, | |
| dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip()); | ||
| } | ||
|
|
||
| // 1. Iterate over each successor block of the given BranchOpInterface | ||
| // operation. | ||
| // 2. For each successor block: | ||
| // a. Retrieve the operands passed to the successor. | ||
| // b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine | ||
| // which | ||
|
||
| // operands are live in the successor block. | ||
| // c. Mark each operand as live or dead based on the analysis. | ||
| // 3. Remove dead operands from the branch operation and arguments accordingly | ||
|
|
||
| static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) { | ||
|
Contributor
I think this needs to be a recursive solution. You can have the same situation in a conditional branch: for example, within a branch you could declare variables and then pass them to a nested conditional branch, in which case they won't be part of your initial successor block arguments.
Contributor
Author
My first idea was to write a recursive function, but I checked other parts of this pass and saw no recursion. It looks like the pass is applied iteratively until a stable point is reached and no more values are deleted. If we did this recursively, we would need to check for maximum depth and cyclic dependencies, and do that for every kind of operation. Do you have any advice, @joker-eph?
Collaborator
This function is called in a That said: what about adding a test to cover the case that @codemzs is describing?
Contributor
Thanks @joker-eph, I completely forgot this was being invoked from `walk()`. We are good on that front.
Contributor
Author
I added such a test, but after looking closely I realized that
Contributor
I am quite surprised
Contributor
Author
It does; the issue is the ordering. I need the "inner" branches to be traversed first, which is what PostOrder does, but with branches there is not much of a hierarchy. I am going to play with it a bit and get back to you.
Contributor
Author
It turns out to be simpler: I added a test with a branching loop that passes multiple dead values around to both conditional and unconditional branches. |
||
| unsigned numSuccessors = branchOp->getNumSuccessors(); | ||
|
|
||
| // Do (1) | ||
| for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { | ||
| Block *successorBlock = branchOp->getSuccessor(succIdx); | ||
|
|
||
| // Do (2) | ||
| SuccessorOperands successorOperands = | ||
| branchOp.getSuccessorOperands(succIdx); | ||
| SmallVector<Value> operandValues; | ||
| for (unsigned operandIdx = 0; operandIdx < successorOperands.size(); | ||
| ++operandIdx) { | ||
| operandValues.push_back(successorOperands[operandIdx]); | ||
| } | ||
|
|
||
| BitVector successorLiveOperands = markLives(operandValues, la); | ||
|
|
||
| // Do (3) | ||
| for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) { | ||
| if (!successorLiveOperands[argIdx]) { | ||
| if (anyBranchUsers(successorBlock->getArgument(argIdx).getUsers())) { | ||
| continue; | ||
| } | ||
|
|
||
| successorOperands.erase(argIdx); | ||
| successorBlock->eraseArgument(argIdx); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
| struct RemoveDeadValues : public impl::RemoveDeadValuesBase<RemoveDeadValues> { | ||
| void runOnOperation() override; | ||
| }; | ||
|
|
@@ -572,26 +626,13 @@ void RemoveDeadValues::runOnOperation() { | |
| auto &la = getAnalysis<RunLivenessAnalysis>(); | ||
| Operation *module = getOperation(); | ||
|
|
||
| // The removal of non-live values is performed iff there are no branch ops, | ||
| // and all symbol user ops present in the IR are call-like. | ||
| WalkResult acceptableIR = module->walk([&](Operation *op) { | ||
| if (op == module) | ||
| return WalkResult::advance(); | ||
| if (isa<BranchOpInterface>(op)) { | ||
| op->emitError() << "cannot optimize an IR with branch ops\n"; | ||
| return WalkResult::interrupt(); | ||
| } | ||
| return WalkResult::advance(); | ||
| }); | ||
|
|
||
| if (acceptableIR.wasInterrupted()) | ||
| return signalPassFailure(); | ||
|
|
||
| module->walk([&](Operation *op) { | ||
| if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) { | ||
| cleanFuncOp(funcOp, module, la); | ||
| } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) { | ||
| cleanRegionBranchOp(regionBranchOp, la); | ||
| } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) { | ||
| cleanBranchOp(branchOp, la); | ||
| } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { | ||
| // Nothing to do here because this is a terminator op and it should be | ||
| // honored with respect to its parent | ||
|
|
||
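Step (3) of `cleanBranchOp` above walks the liveness bits from the highest index down, so each erase leaves the lower indices still to be visited unchanged. The following is a minimal standalone sketch of that idea only, not the pass itself. `SuccessorEdge` and `eraseDeadPairs` are hypothetical names, plain `std::vector`s stand in for MLIR's `SuccessorOperands` and block arguments, and the `anyBranchUsers` guard from the diff is left out.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for one successor edge of a branch op: the operands
// forwarded by the branch and the block arguments that receive them, matched
// index for index. This is NOT an MLIR type.
struct SuccessorEdge {
  std::vector<std::string> operands;
  std::vector<std::string> blockArgs;
};

// Erase every operand/argument pair whose liveness bit is false.
// Iterating from the highest index down means an erase never shifts an
// element that still has to be visited.
inline void eraseDeadPairs(SuccessorEdge &edge, const std::vector<bool> &live) {
  assert(live.size() == edge.operands.size());
  assert(live.size() == edge.blockArgs.size());
  for (std::size_t i = live.size(); i-- > 0;) {
    if (live[i])
      continue; // Live pairs are kept untouched.
    const auto offset = static_cast<std::ptrdiff_t>(i);
    edge.operands.erase(edge.operands.begin() + offset);
    edge.blockArgs.erase(edge.blockArgs.begin() + offset);
  }
}
```

Calling `eraseDeadPairs` on an edge with operands `a, b, c, d` and liveness bits `true, false, true, false` leaves operands `a, c` and the matching two block arguments. A forward loop would need to adjust its index after every erase, which is the bookkeeping the reverse loop avoids.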
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -28,22 +28,59 @@ module @named_module_acceptable { | |
|
|
||
| // ----- | ||
|
|
||
| // The IR remains untouched because of the presence of a branch op `cf.cond_br`. | ||
| // The IR is optimized regardless of the presence of a branch op `cf.cond_br`. | ||
| // | ||
| func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { | ||
| func.func @acceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { | ||
| %non_live = arith.constant 0 : i32 | ||
| // expected-error @+1 {{cannot optimize an IR with branch ops}} | ||
| // CHECK-NOT: non_live | ||
| cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32) | ||
| ^bb1(%non_live_0 : i32): | ||
| // CHECK-NOT: non_live_0 | ||
| cf.br ^bb3 | ||
| ^bb2(%non_live_1 : i32): | ||
| // CHECK-NOT: non_live_1 | ||
| cf.br ^bb3 | ||
| ^bb3: | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // Arguments of unconditional branch op `cf.br` are properly removed. | ||
| // | ||
| func.func @acceptable_ir_has_cleanable_simple_op_with_unconditional_branch_op(%arg0: i1) { | ||
| %non_live = arith.constant 0 : i32 | ||
| // CHECK-NOT: non_live | ||
| cf.br ^bb1(%non_live : i32) | ||
|
Contributor
Let's add a check for
Contributor
Author
Updated, thank you! |
||
| ^bb1(%non_live_1 : i32): | ||
|
||
| // CHECK-NOT: non_live_1 | ||
| cf.br ^bb3(%non_live_1 : i32) | ||
| // CHECK-NOT: non_live_2 | ||
| ^bb3(%non_live_2 : i32): | ||
| return | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // Checking that iter_args are properly handled | ||
| // | ||
| func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { | ||
| %c0 = arith.constant 0 : index | ||
| %c1 = arith.constant 1 : index | ||
| %c10 = arith.constant 10 : index | ||
| %non_live = arith.constant 0 : index | ||
| // CHECK-NOT: non_live | ||
|
||
| %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) { | ||
| %new_live = arith.addi %live_arg, %i : index | ||
| // CHECK: scf.for %[[ARG_0:.*]] = %c0 to %c10 step %c1 iter_args(%[[ARG_1:.*]] = %arg0) | ||
| scf.yield %new_live, %non_live_arg : index, index | ||
| } | ||
| // CHECK-NOT: result_non_live | ||
| return %result : index | ||
| } | ||
|
|
||
| // ----- | ||
|
|
||
| // Note that this cleanup cannot be done by the `canonicalize` pass. | ||
| // | ||
| // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() { | ||
|
|
||
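The `cf.br` test case earlier in this file forwards a dead constant through a chain of block arguments (`%non_live` to `%non_live_1` to `%non_live_2`), and the review thread mentions a loop that passes dead values around. The sketch below is a deliberately simplified model of why such values can all be treated as dead. It is not MLIR's `RunLivenessAnalysis`: `ValueInfo` and `computeLive` are hypothetical names, and the rule "a value is live if it has a non-branch use or is forwarded to a live block argument" is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of an SSA value. `forwardedTo` holds the indices of the
// block arguments (also values in the same table) that receive this value
// through a branch operand.
struct ValueInfo {
  bool hasNonBranchUse;
  std::vector<std::size_t> forwardedTo;
};

// Assumed rule: a value is live if it has a non-branch use, or if it is
// forwarded to a block argument that is itself live. Iterating until nothing
// changes handles chains and cycles without recursion.
inline std::vector<bool> computeLive(const std::vector<ValueInfo> &values) {
  std::vector<bool> live(values.size(), false);
  for (std::size_t i = 0; i < values.size(); ++i)
    live[i] = values[i].hasNonBranchUse;

  bool changed = true;
  while (changed) {
    changed = false;
    for (std::size_t i = 0; i < values.size(); ++i) {
      if (live[i])
        continue;
      for (std::size_t target : values[i].forwardedTo) {
        assert(target < values.size());
        if (live[target]) {
          live[i] = true;
          changed = true;
          break;
        }
      }
    }
  }
  return live;
}
```

Under this model a chain or cycle of values that are only ever forwarded stays dead as a whole, while a single real use at the end of the chain makes every value that feeds it live. The loop terminates because each pass either marks at least one more value live or stops.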