@@ -1073,6 +1073,8 @@ class ThreadSafetyAnalyzer {
10731073 ProtectedOperationKind POK);
10741074 void checkPtAccess (const FactSet &FSet, const Expr *Exp, AccessKind AK,
10751075 ProtectedOperationKind POK);
1076+
1077+ void checkMismatchedFunctionAttrs (const FunctionDecl *FD);
10761078};
10771079
10781080} // namespace
@@ -2263,34 +2265,25 @@ static bool neverReturns(const CFGBlock *B) {
22632265 return false ;
22642266}
22652267
2266- template <typename AttrT>
2267- static SmallVector<const Expr *> collectAttrArgs (const FunctionDecl *FD) {
2268- SmallVector<const Expr *> Args;
2269- for (const AttrT *A : FD->specific_attrs <AttrT>()) {
2270- for (const Expr *E : A->args ())
2271- Args.push_back (E);
2272- }
2268+ void ThreadSafetyAnalyzer::checkMismatchedFunctionAttrs (
2269+ const FunctionDecl *FD) {
2270+ FD = FD->getMostRecentDecl ();
22732271
2274- return Args;
2275- }
2276-
2277- static void diagnoseMismatchedFunctionAttrs (const FunctionDecl *FD,
2278- ThreadSafetyHandler &Handler) {
2279- assert (FD);
2280- FD = FD->getDefinition ();
2281- assert (FD);
2282- auto FDArgs = collectAttrArgs<RequiresCapabilityAttr>(FD);
2272+ auto collectCapabilities = [&](const FunctionDecl *FD) {
2273+ SmallVector<CapabilityExpr> Args;
2274+ for (const auto *A : FD->specific_attrs <RequiresCapabilityAttr>()) {
2275+ for (const Expr *E : A->args ())
2276+ Args.push_back (SxBuilder.translateAttrExpr (E, nullptr ));
2277+ }
2278+ return Args;
2279+ };
22832280
2281+ auto FDArgs = collectCapabilities (FD);
22842282 for (const FunctionDecl *D = FD->getPreviousDecl (); D;
22852283 D = D->getPreviousDecl ()) {
2286- auto DArgs = collectAttrArgs<RequiresCapabilityAttr>(D);
2287-
2288- for (const Expr *E : FDArgs) {
2289- if (!llvm::is_contained (DArgs, E)) {
2290- // FD requires E, but D doesn't.
2291- Handler.handleAttributeMismatch (FD, D);
2292- }
2293- }
2284+ auto DArgs = collectCapabilities (D);
2285+ if (DArgs.size () != FDArgs.size ())
2286+ Handler.handleAttributeMismatch (FD, D);
22942287 }
22952288}
22962289
@@ -2314,7 +2307,7 @@ void ThreadSafetyAnalyzer::runAnalysis(AnalysisDeclContext &AC) {
23142307 CurrentFunction = dyn_cast<FunctionDecl>(D);
23152308
23162309 if (CurrentFunction)
2317- diagnoseMismatchedFunctionAttrs (CurrentFunction, Handler );
2310+ checkMismatchedFunctionAttrs (CurrentFunction);
23182311
23192312 if (D->hasAttr <NoThreadSafetyAnalysisAttr>())
23202313 return ;
0 commit comments