@@ -563,6 +563,44 @@ static void cleanRegionBranchOp(RegionBranchOpInterface regionBranchOp,
563563 dropUsesAndEraseResults (regionBranchOp.getOperation (), resultsToKeep.flip ());
564564}
565565
566+ // 1. Iterate over each successor block of the given BranchOpInterface
567+ // operation.
568+ // 2. For each successor block:
569+ // a. Retrieve the operands passed to the successor.
570+ // b. Use the provided liveness analysis (`RunLivenessAnalysis`) to determine
571+ // which
572+ // operands are live in the successor block.
573+ // c. Mark each operand as live or dead based on the analysis.
574+ // 3. Remove dead operands from the branch operation and arguments accordingly
575+
576+ static void cleanBranchOp (BranchOpInterface branchOp, RunLivenessAnalysis &la) {
577+ unsigned numSuccessors = branchOp->getNumSuccessors ();
578+
579+ // Do (1)
580+ for (unsigned succIdx = 0 ; succIdx < numSuccessors; ++succIdx) {
581+ Block *successorBlock = branchOp->getSuccessor (succIdx);
582+
583+ // Do (2)
584+ SuccessorOperands successorOperands =
585+ branchOp.getSuccessorOperands (succIdx);
586+ SmallVector<Value> operandValues;
587+ for (unsigned operandIdx = 0 ; operandIdx < successorOperands.size ();
588+ ++operandIdx) {
589+ operandValues.push_back (successorOperands[operandIdx]);
590+ }
591+
592+ BitVector successorLiveOperands = markLives (operandValues, la);
593+
594+ // Do (3)
595+ for (int argIdx = successorLiveOperands.size () - 1 ; argIdx >= 0 ; --argIdx) {
596+ if (!successorLiveOperands[argIdx]) {
597+ successorOperands.erase (argIdx);
598+ successorBlock->eraseArgument (argIdx);
599+ }
600+ }
601+ }
602+ }
603+
566604struct RemoveDeadValues : public impl ::RemoveDeadValuesBase<RemoveDeadValues> {
567605 void runOnOperation () override ;
568606};
@@ -572,26 +610,13 @@ void RemoveDeadValues::runOnOperation() {
572610 auto &la = getAnalysis<RunLivenessAnalysis>();
573611 Operation *module = getOperation ();
574612
575- // The removal of non-live values is performed iff there are no branch ops,
576- // and all symbol user ops present in the IR are call-like.
577- WalkResult acceptableIR = module ->walk ([&](Operation *op) {
578- if (op == module )
579- return WalkResult::advance ();
580- if (isa<BranchOpInterface>(op)) {
581- op->emitError () << " cannot optimize an IR with branch ops\n " ;
582- return WalkResult::interrupt ();
583- }
584- return WalkResult::advance ();
585- });
586-
587- if (acceptableIR.wasInterrupted ())
588- return signalPassFailure ();
589-
590613 module ->walk ([&](Operation *op) {
591614 if (auto funcOp = dyn_cast<FunctionOpInterface>(op)) {
592615 cleanFuncOp (funcOp, module , la);
593616 } else if (auto regionBranchOp = dyn_cast<RegionBranchOpInterface>(op)) {
594617 cleanRegionBranchOp (regionBranchOp, la);
618+ } else if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
619+ cleanBranchOp (branchOp, la);
595620 } else if (op->hasTrait <::mlir::OpTrait::IsTerminator>()) {
596621 // Nothing to do here because this is a terminator op and it should be
597622 // honored with respect to its parent
0 commit comments