Skip to content

Commit e3d5d54

Browse files
Avoiding the expensive symbol look-up
1 parent 78ae4f1 commit e3d5d54

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

mlir/lib/Transforms/RemoveDeadValues.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ struct FunctionToCleanUp {
8888
struct OperationToCleanup {
8989
Operation *op;
9090
BitVector nonLive;
91+
Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function
9192
};
9293

9394
struct BlockArgsToCleanup {
@@ -316,7 +317,8 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
316317
// Push an empty operand cleanup entry so that call-site specific logic in
317318
// cleanUpDeadVals runs (it keys off CallOpInterface). The BitVector is
318319
// intentionally all false to avoid generic erasure.
319-
cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false)});
320+
// Store the funcOp as the callee to avoid expensive symbol lookup later.
321+
cl.operands.push_back({callOp, BitVector(callOp->getNumOperands(), false), funcOp.getOperation()});
320322
}
321323

322324
// Do (3).
@@ -768,9 +770,9 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
768770
LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
769771
for (OperationToCleanup &o : list.operands) {
770772
if (auto call = dyn_cast<CallOpInterface>(o.op)) {
771-
if (SymbolRefAttr sym = call.getCallableForCallee().dyn_cast<SymbolRefAttr>()) {
772-
Operation *callee = SymbolTable::lookupNearestSymbolFrom(o.op, sym);
773-
auto it = erasedFuncArgs.find(callee);
773+
// Use the stored callee reference if available, avoiding expensive symbol lookup
774+
if (o.callee) {
775+
auto it = erasedFuncArgs.find(o.callee);
774776
if (it != erasedFuncArgs.end()) {
775777
const BitVector &deadArgIdxs = it->second;
776778
MutableOperandRange args = call.getArgOperandsMutable();
@@ -788,7 +790,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
788790
int operandOffset = call.getArgOperands().getBeginOperandIndex();
789791
for (int argIdx : deadArgIdxs.set_bits()) {
790792
int operandNumber = operandOffset + argIdx;
791-
if (operandNumber < o.nonLive.size())
793+
if (operandNumber < static_cast<int>(o.nonLive.size()))
792794
o.nonLive.reset(operandNumber);
793795
}
794796
}

0 commit comments

Comments
 (0)