diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 0aa9dcb36681b..4241974ee290d 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -172,7 +172,7 @@ static SmallVector operandsToOpOperands(OperandRange operands) { /// iff it has no memory effects and none of its results are live. /// /// It is assumed that `op` is simple. Here, a simple op is one which isn't a -/// symbol op, a symbol-user op, a region branch op, a branch op, a region +/// function-like op, a call-like op, a region branch op, a branch op, a region /// branch terminator op, or return-like. static void cleanSimpleOp(Operation *op, RunLivenessAnalysis &la) { if (!isMemoryEffectFree(op) || hasLive(op->getResults(), la)) @@ -563,6 +563,51 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp, dropUsesAndEraseResults(regionBranchOp.getOperation(), resultsToKeep.flip()); } +// 1. Iterate over each successor block of the given BranchOpInterface +// operation. +// 2. For each successor block: +// a. Retrieve the operands passed to the successor. +// b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine +// which operands are live in the successor block. +// c. Mark each operand as live or dead based on the analysis. +// 3. 
Remove dead operands from the branch operation and arguments accordingly + +static void cleanBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la) { + unsigned numSuccessors = branchOp->getNumSuccessors(); + + // Do (1) + for (unsigned succIdx = 0; succIdx < numSuccessors; ++succIdx) { + Block *successorBlock = branchOp->getSuccessor(succIdx); + + // Do (2) + SuccessorOperands successorOperands = + branchOp.getSuccessorOperands(succIdx); + SmallVector operandValues; + for (unsigned operandIdx = 0; operandIdx < successorOperands.size(); + ++operandIdx) { + operandValues.push_back(successorOperands[operandIdx]); + } + + BitVector successorLiveOperands = markLives(operandValues, la); + + // Do (3) + for (int argIdx = successorLiveOperands.size() - 1; argIdx >= 0; --argIdx) { + if (!successorLiveOperands[argIdx]) { + if (successorBlock->getNumArguments() < successorOperands.size()) { + // if block was cleaned through a different code path + // we only need to remove operands from the invocation + successorOperands.erase(argIdx); + continue; + } + + successorBlock->getArgument(argIdx).dropAllUses(); + successorOperands.erase(argIdx); + successorBlock->eraseArgument(argIdx); + } + } + } +} + struct RemoveDeadValues : public impl::RemoveDeadValuesBase { void runOnOperation() override; }; @@ -572,26 +617,13 @@ void RemoveDeadValues::runOnOperation() { auto &la = getAnalysis(); Operation *module = getOperation(); - // The removal of non-live values is performed iff there are no branch ops, - // and all symbol user ops present in the IR are call-like. 
- WalkResult acceptableIR = module->walk([&](Operation *op) { - if (op == module) - return WalkResult::advance(); - if (isa(op)) { - op->emitError() << "cannot optimize an IR with branch ops\n"; - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - - if (acceptableIR.wasInterrupted()) - return signalPassFailure(); - module->walk([&](Operation *op) { if (auto funcOp = dyn_cast(op)) { cleanFuncOp(funcOp, module, la); } else if (auto regionBranchOp = dyn_cast(op)) { cleanRegionBranchOp(regionBranchOp, la); + } else if (auto branchOp = dyn_cast(op)) { + cleanBranchOp(branchOp, la); } else if (op->hasTrait<::mlir::OpTrait::IsTerminator>()) { // Nothing to do here because this is a terminator op and it should be // honored with respect to its parent diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 826f6159a36b6..5f06a54a1ef8b 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -28,22 +28,51 @@ module @named_module_acceptable { // ----- -// The IR remains untouched because of the presence of a branch op `cf.cond_br`. 
+// The IR contains both conditional and unconditional branches with a loop +// in which the last cf.cond_br is referencing the first cf.br // -func.func @dont_touch_unacceptable_ir_has_cleanable_simple_op_with_branch_op(%arg0: i1) { +func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: i1) { %non_live = arith.constant 0 : i32 - // expected-error @+1 {{cannot optimize an IR with branch ops}} - cf.cond_br %arg0, ^bb1(%non_live : i32), ^bb2(%non_live : i32) -^bb1(%non_live_0 : i32): - cf.br ^bb3 -^bb2(%non_live_1 : i32): - cf.br ^bb3 -^bb3: + // CHECK-NOT: arith.constant + cf.br ^bb1(%non_live : i32) + // CHECK: cf.br ^[[BB1:bb[0-9]+]] +^bb1(%non_live_1 : i32): + // CHECK: ^[[BB1]]: + %non_live_5 = arith.constant 1 : i32 + cf.br ^bb3(%non_live_1, %non_live_5 : i32, i32) + // CHECK: cf.br ^[[BB3:bb[0-9]+]] + // CHECK-NOT: i32 +^bb3(%non_live_2 : i32, %non_live_6 : i32): + // CHECK: ^[[BB3]]: + cf.cond_br %arg0, ^bb1(%non_live_2 : i32), ^bb4(%non_live_2 : i32) + // CHECK: cf.cond_br %arg0, ^[[BB1]], ^[[BB4:bb[0-9]+]] +^bb4(%non_live_4 : i32): + // CHECK: ^[[BB4]]: return } // ----- +// Checking that iter_args are properly handled +// +func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c10 = arith.constant 10 : index + %non_live = arith.constant 0 : index + // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { + %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) { + // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index + %new_live = arith.addi %live_arg, %i : index + // CHECK: scf.yield [[SUM:%.+]] + scf.yield %new_live, %non_live_arg : index, index + } + // CHECK: return [[RESULT]] : index + return %result : index +} + +// ----- + // Note that this cleanup cannot be done by the `canonicalize` pass. 
// // CHECK-LABEL: func.func private @clean_func_op_remove_argument_and_return_value() {