@@ -172,7 +172,7 @@ static SmallVector<OpOperand *> operandsToOpOperands(OperandRange operands) {
172172// / iff it has no memory effects and none of its results are live.
173173// /
174174// / It is assumed that `op` is simple. Here, a simple op is one which isn't a
175- // / symbol op, a symbol-user op, a region branch op, a branch op, a region
175+ // / function-like op, a call-like op, a region branch op, a branch op, a region
176176// / branch terminator op, or return-like.
177177static void cleanSimpleOp (Operation *op, RunLivenessAnalysis &la) {
178178 if (!isMemoryEffectFree (op) || hasLive (op->getResults (), la))
@@ -563,6 +563,51 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
563563 dropUsesAndEraseResults (regionBranchOp.getOperation (), resultsToKeep.flip ());
564564}
565565
566+ // 1. Iterate over each successor block of the given BranchOpInterface
567+ // operation.
568+ // 2. For each successor block:
569+ // a. Retrieve the operands passed to the successor.
570+ // b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine
571+ // which operands are live in the successor block.
572+ // c. Mark each operand as live or dead based on the analysis.
573+ // 3. Remove dead operands from the branch operation and arguments accordingly
574+
575+ static void cleanBranchOp (BranchOpInterface branchOp, RunLivenessAnalysis &la) {
576+ unsigned numSuccessors = branchOp->getNumSuccessors ();
577+
578+ // Do (1)
579+ for (unsigned succIdx = 0 ; succIdx < numSuccessors; ++succIdx) {
580+ Block *successorBlock = branchOp->getSuccessor (succIdx);
581+
582+ // Do (2)
583+ SuccessorOperands successorOperands =
584+ branchOp.getSuccessorOperands (succIdx);
585+ SmallVector<Value> operandValues;
586+ for (unsigned operandIdx = 0 ; operandIdx < successorOperands.size ();
587+ ++operandIdx) {
588+ operandValues.push_back (successorOperands[operandIdx]);
589+ }
590+
591+ BitVector successorLiveOperands = markLives (operandValues, la);
592+
593+ // Do (3)
594+ for (int argIdx = successorLiveOperands.size () - 1 ; argIdx >= 0 ; --argIdx) {
595+ if (!successorLiveOperands[argIdx]) {
596+ if (successorBlock->getNumArguments () < successorOperands.size ()) {
597+ // if block was cleaned through a different code path
598+ // we only need to remove operands from the invokation
599+ successorOperands.erase (argIdx);
600+ continue ;
601+ }
602+
603+ successorBlock->getArgument (argIdx).dropAllUses ();
604+ successorOperands.erase (argIdx);
605+ successorBlock->eraseArgument (argIdx);
606+ }
607+ }
608+ }
609+ }
610+
566611struct RemoveDeadValues : public impl ::RemoveDeadValuesBase<RemoveDeadValues> {
567612 void runOnOperation () override ;
568613};
@@ -572,26 +617,13 @@ void RemoveDeadValues::runOnOperation() {
572617 auto &la = getAnalysis<RunLivenessAnalysis>();
573618 Operation *module = getOperation ();
574619
575- // The removal of non-live values is performed iff there are no branch ops,
576- // and all symbol user ops present in the IR are call-like.
577- WalkResult acceptableIR = module ->walk ([&](Operation *op) {
578- if (op == module )
579- return WalkResult::advance ();
580- if (isa<BranchOpInterface>(op)) {
581- op->emitError () << " cannot optimize an IR with branch ops\n " ;
582- return WalkResult::interrupt ();
583- }
584- return WalkResult::advance ();
585- });
586-
587- if (acceptableIR.wasInterrupted ())
588- return signalPassFailure ();
589-
590620 module ->walk ([&](Operation *op) {
591621 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
592622 cleanFuncOp (funcOp, module , la);
593623 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
594624 cleanRegionBranchOp (regionBranchOp, la);
625+ } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
626+ cleanBranchOp (branchOp, la);
595627 } else if (op->hasTrait <::mlir::OpTrait::IsTerminator>()) {
596628 // Nothing to do here because this is a terminator op and it should be
597629 // honored with respect to its parent
0 commit comments