@@ -306,19 +306,17 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
306306 nonLiveSet.insert (arg);
307307 }
308308
309- // Do (2).
309+ // Do (2). (Skip creating generic operand cleanup entries for call ops.
310+ // Call arguments will be removed in the call-site specific segment-aware
311+ // cleanup, avoiding generic eraseOperands bitvector mechanics.)
310312 SymbolTable::UseRange uses = *funcOp.getSymbolUses (module );
311313 for (SymbolTable::SymbolUse use : uses) {
312314 Operation *callOp = use.getUser ();
313315 assert (isa<CallOpInterface>(callOp) && " expected a call-like user" );
314- // The number of operands in the call op may not match the number of
315- // arguments in the func op.
316- BitVector nonLiveCallOperands (callOp->getNumOperands (), false );
317- SmallVector<OpOperand *> callOpOperands =
318- operandsToOpOperands (cast<CallOpInterface>(callOp).getArgOperands ());
319- for (int index : nonLiveArgs.set_bits ())
320- nonLiveCallOperands.set (callOpOperands[index]->getOperandNumber ());
321- cl.operands .push_back ({callOp, nonLiveCallOperands});
316+ // Push an empty operand cleanup entry so that call-site specific logic in
317+ // cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
318+ // intentionally all false to avoid generic erasure.
319+ cl.operands .push_back ({callOp, BitVector (callOp->getNumOperands (), false )});
322320 }
323321
324322 // Do (3).
@@ -746,6 +744,10 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
746744
747745 // 3. Functions
748746 LDBG () << " Cleaning up " << list.functions .size () << " functions" ;
747+ // Record which function arguments were erased so we can shrink call-site
748+ // argument segments for CallOpInterface operations (e.g. ops using
749+ // AttrSizedOperandSegments) in the next phase.
750+ DenseMap<Operation *, BitVector> erasedFuncArgs;
749751 for (auto &f : list.functions ) {
750752 LDBG () << " Cleaning up function: " << f.funcOp .getOperation ()->getName ();
751753 LDBG () << " Erasing " << f.nonLiveArgs .count () << " non-live arguments" ;
@@ -754,17 +756,51 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
754756 // Some functions may not allow erasing arguments or results. These calls
755757 // return failure in such cases without modifying the function, so it's okay
756758 // to proceed.
757- (void )f.funcOp .eraseArguments (f.nonLiveArgs );
759+ if (succeeded (f.funcOp .eraseArguments (f.nonLiveArgs ))) {
760+ // Record only if we actually erased something.
761+ if (f.nonLiveArgs .any ())
762+ erasedFuncArgs.try_emplace (f.funcOp .getOperation (), f.nonLiveArgs );
763+ }
758764 (void )f.funcOp .eraseResults (f.nonLiveRets );
759765 }
760766
761767 // 4. Operands
762768 LDBG () << " Cleaning up " << list.operands .size () << " operand lists" ;
763769 for (OperationToCleanup &o : list.operands ) {
764- if (o.op ->getNumOperands () > 0 ) {
765- LDBG () << " Erasing " << o.nonLive .count ()
766- << " non-live operands from operation: "
767- << OpWithFlags (o.op , OpPrintingFlags ().skipRegions ());
770+ if (auto call = dyn_cast<CallOpInterface>(o.op )) {
771+ if (SymbolRefAttr sym = call.getCallableForCallee ().dyn_cast <SymbolRefAttr>()) {
772+ Operation *callee = SymbolTable::lookupNearestSymbolFrom (o.op , sym);
773+ auto it = erasedFuncArgs.find (callee);
774+ if (it != erasedFuncArgs.end ()) {
775+ const BitVector &deadArgIdxs = it->second ;
776+ MutableOperandRange args = call.getArgOperandsMutable ();
777+ // First, erase the call arguments corresponding to erased callee args.
778+ for (int i = static_cast <int >(args.size ()) - 1 ; i >= 0 ; --i) {
779+ if (i < static_cast <int >(deadArgIdxs.size ()) && deadArgIdxs.test (i))
780+ args.erase (i);
781+ }
782+ // If this operand cleanup entry also has a generic nonLive bitvector,
783+ // clear bits for call arguments we already erased above to avoid
784+ // double-erasing (which could impact other segments of ops with
785+ // AttrSizedOperandSegments).
786+ if (o.nonLive .any ()) {
787+ // Map the argument logical index to the operand number(s) recorded.
788+ SmallVector<OpOperand *> callOperands =
789+ operandsToOpOperands (call.getArgOperands ());
790+ for (int argIdx : deadArgIdxs.set_bits ()) {
791+ if (argIdx < static_cast <int >(callOperands.size ())) {
792+ unsigned operandNumber = callOperands[argIdx]->getOperandNumber ();
793+ if (operandNumber < o.nonLive .size ())
794+ o.nonLive .reset (operandNumber);
795+ }
796+ }
797+ }
798+ }
799+ }
800+ }
801+ // Only perform generic operand erasure for non-call ops; for call ops we
802+ // already handled argument removals via the segment-aware path above.
803+ if (!isa<CallOpInterface>(o.op ) && o.nonLive .any ()) {
768804 o.op ->eraseOperands (o.nonLive );
769805 }
770806 }
0 commit comments