diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 1e0ec41fdc9b..add9d3a93c70 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -4255,6 +4255,16 @@ class CIR_CallOp extra_traits = []> // For indirect call, the operand list is shifted by one. setOperand(index + 1, value); } + + /// If this is a direct call, returns the callee as a `cir::FuncOp` in `symbolTable`. + /// Otherwise, returns `null`. + cir::FuncOp getDirectCallee(mlir::SymbolTable &symbolTable); + + /// If this is a direct call, returns the callee as a `cir::FuncOp` in parent `module`. + /// Otherwise, returns `null`. + /// NOTE: This method walks the symbol table. If you are calling this method a lot, + /// consider using `cir::FuncOp::getDirectCallee(mlir::SymbolTable &)` instead. + cir::FuncOp getDirectCallee(mlir::ModuleOp module); }]; let hasCustomAssemblyFormat = 1; diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 14528fd36675..e353170cdb23 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -2950,6 +2950,21 @@ unsigned cir::CallOp::getNumArgOperands() { return this->getOperation()->getNumOperands(); } +cir::FuncOp cir::CallOp::getDirectCallee(mlir::SymbolTable &symbolTable) { + if (!getCallee()) + return {}; + llvm::StringRef name = *getCallee(); + return symbolTable.lookup(name); +} + +cir::FuncOp cir::CallOp::getDirectCallee(mlir::ModuleOp module) { + if (!getCallee()) + return {}; + llvm::StringRef name = *getCallee(); + mlir::Operation *global = mlir::SymbolTable::lookupSymbolIn(module, name); + return mlir::dyn_cast_if_present(global); +} + static LogicalResult verifyCallCommInSymbolUses(Operation *op, SymbolTableCollection &symbolTable) { // Callee attribute only need on indirect calls. diff --git a/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp b/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp index 7f980b587fe2..a528687adf30 100644 --- a/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp @@ -1443,17 +1443,8 @@ void LifetimeCheckPass::checkPointerDeref(mlir::Value addr, mlir::Location loc, emitPsetRemark(); } -static FuncOp getCalleeFromSymbol(ModuleOp mod, llvm::StringRef name) { - auto global = mlir::SymbolTable::lookupSymbolIn(mod, name); - assert(global && "expected to find symbol for function"); - return dyn_cast(global); -} - static const ASTCXXMethodDeclInterface getMethod(ModuleOp mod, CallOp callOp) { - if (!callOp.getCallee()) - return nullptr; - llvm::StringRef name = *callOp.getCallee(); - auto method = getCalleeFromSymbol(mod, name); + cir::FuncOp method = callOp.getDirectCallee(mod); if (!method || method.getBuiltin()) return nullptr; return dyn_cast(method.getAstAttr()); @@ -1756,12 +1747,10 @@ bool LifetimeCheckPass::isTaskType(mlir::Value taskVal) { } void LifetimeCheckPass::trackCallToCoroutine(CallOp callOp) { - if (auto fnName = callOp.getCallee()) { - auto calleeFuncOp = getCalleeFromSymbol(theModule, *fnName); - if (calleeFuncOp && - (calleeFuncOp.getCoroutine() || - (calleeFuncOp.isDeclaration() && callOp->getNumResults() > 0 && - isTaskType(callOp->getResult(0))))) { + if (cir::FuncOp callee = callOp.getDirectCallee(theModule)) { + if (callee.getCoroutine() || + (callee.isDeclaration() && callOp->getNumResults() > 0 && + isTaskType(callOp->getResult(0)))) { currScope->localTempTasks.insert(callOp->getResult(0)); } return; @@ -1792,13 +1781,11 @@ void LifetimeCheckPass::checkCall(CallOp callOp) { // From this point on only owner and pointer class methods handling, // starting from special methods. - if (auto fnName = callOp.getCallee()) { - auto calleeFuncOp = getCalleeFromSymbol(theModule, *fnName); - if (calleeFuncOp && calleeFuncOp.getCxxSpecialMember()) - if (auto cxxCtor = - dyn_cast(*calleeFuncOp.getCxxSpecialMember())) - return checkCtor(callOp, cxxCtor); - } + cir::FuncOp callee = callOp.getDirectCallee(theModule); + if (callee && callee.getCxxSpecialMember()) + if (auto cxxCtor = + dyn_cast(*callee.getCxxSpecialMember())) + return checkCtor(callOp, cxxCtor); if (methodDecl.isMoveAssignmentOperator()) return checkMoveAssignment(callOp, methodDecl); if (methodDecl.isCopyAssignmentOperator())