5252#include " clang/AST/DeclTemplate.h"
5353#include " clang/AST/Expr.h"
5454#include " clang/AST/ExprCXX.h"
55- #include " clang/AST/TemplateBase.h"
56- #include " clang/AST/Type.h"
5755
5856#include " clang/AST/ParentMap.h"
5957#include " clang/ASTMatchers/ASTMatchFinder.h"
@@ -1113,17 +1111,23 @@ class StopTrackingCallback final : public SymbolVisitor {
11131111// / The visitor traverses reachable symbols from a given set of memory regions
11141112// / (typically smart pointer field regions) and marks any allocated symbols as
11151113// / escaped. Escaped symbols are not reported as leaks by checkDeadSymbols.
1116- // /
1117- // / Usage:
1118- // / auto Scan =
1119- // / State->scanReachableSymbols<EscapeTrackedCallback>(RootRegions);
1120- // / ProgramStateRef NewState = Scan.getState();
1121- // / if (NewState != State) C.addTransition(NewState);
11221114class EscapeTrackedCallback final : public SymbolVisitor {
11231115 ProgramStateRef State;
11241116
1125- public:
11261117 explicit EscapeTrackedCallback (ProgramStateRef S) : State(std::move(S)) {}
1118+
1119+ public:
1120+ // / Escape tracked regions reachable from the given roots.
1121+ static ProgramStateRef
1122+ EscapeTrackedRegionsReachableFrom (ArrayRef<const MemRegion *> Roots,
1123+ ProgramStateRef State) {
1124+ EscapeTrackedCallback Visitor (State);
1125+ for (const MemRegion *R : Roots) {
1126+ State->scanReachableSymbols (loc::MemRegionVal (R), Visitor);
1127+ }
1128+ return Visitor.getState ();
1129+ }
1130+
11271131 ProgramStateRef getState () const { return State; }
11281132
11291133 bool VisitSymbol (SymbolRef Sym) override {
@@ -3107,24 +3111,13 @@ void MallocChecker::checkDeadSymbols(SymbolReaper &SymReaper,
31073111 C.addTransition (state->set <RegionState>(RS), N);
31083112}
31093113
3110- static QualType canonicalStrip (QualType QT) {
3111- return QT.getCanonicalType ().getUnqualifiedType ();
3112- }
3113-
3114- static bool isInStdNamespace (const DeclContext *DC) {
3115- while (DC) {
3116- if (const auto *NS = dyn_cast<NamespaceDecl>(DC))
3117- if (NS->isStdNamespace ())
3118- return true ;
3119- DC = DC->getParent ();
3120- }
3121- return false ;
3122- }
3114+ // Use isWithinStdNamespace from CheckerHelpers.h instead of custom
3115+ // implementation
31233116
31243117// Allowlist of owning smart pointers we want to recognize.
31253118// Start with unique_ptr and shared_ptr. (intentionally exclude weak_ptr)
31263119static bool isSmartOwningPtrType (QualType QT) {
3127- QT = canonicalStrip (QT );
3120+ QT = QT-> getCanonicalTypeUnqualified ( );
31283121
31293122 // First try TemplateSpecializationType (for std smart pointers)
31303123 const auto *TST = QT->getAs <TemplateSpecializationType>();
@@ -3138,30 +3131,75 @@ static bool isSmartOwningPtrType(QualType QT) {
31383131 return false ;
31393132
31403133 // Check if it's in std namespace
3141- const DeclContext *DC = ND->getDeclContext ();
3142- if (!isInStdNamespace (DC))
3134+ if (!isWithinStdNamespace (ND))
31433135 return false ;
31443136
31453137 StringRef Name = ND->getName ();
31463138 return Name == " unique_ptr" || Name == " shared_ptr" ;
31473139 }
31483140
31493141 // Also try RecordType (for custom smart pointer implementations)
3150- const auto *RT = QT->getAs <RecordType>();
3151- if (RT) {
3152- const auto *RD = RT->getDecl ();
3153- if (RD) {
3154- StringRef Name = RD->getName ();
3155- if (Name == " unique_ptr" || Name == " shared_ptr" ) {
3156- // Accept any custom unique_ptr or shared_ptr implementation
3157- return true ;
3158- }
3142+ const auto *RD = QT->getAsCXXRecordDecl ();
3143+ if (RD) {
3144+ StringRef Name = RD->getName ();
3145+ if (Name == " unique_ptr" || Name == " shared_ptr" ) {
3146+ // Accept any custom unique_ptr or shared_ptr implementation
3147+ return true ;
31593148 }
31603149 }
31613150
31623151 return false ;
31633152}
31643153
3154+ static bool hasSmartPtrField (const CXXRecordDecl *CRD) {
3155+ // Check direct fields
3156+ if (llvm::any_of (CRD->fields (), [](const FieldDecl *FD) {
3157+ return isSmartOwningPtrType (FD->getType ());
3158+ }))
3159+ return true ;
3160+
3161+ // Check fields from base classes
3162+ for (const CXXBaseSpecifier &Base : CRD->bases ()) {
3163+ if (const CXXRecordDecl *BaseDecl = Base.getType ()->getAsCXXRecordDecl ()) {
3164+ if (hasSmartPtrField (BaseDecl))
3165+ return true ;
3166+ }
3167+ }
3168+ return false ;
3169+ }
3170+
3171+ static bool isRvalueByValueRecord (const Expr *AE) {
3172+ if (AE->isGLValue ())
3173+ return false ;
3174+
3175+ QualType T = AE->getType ();
3176+ if (!T->isRecordType () || T->isReferenceType ())
3177+ return false ;
3178+
3179+ // Accept common temp/construct forms but don't overfit.
3180+ return isa<CXXTemporaryObjectExpr, MaterializeTemporaryExpr, CXXConstructExpr,
3181+ InitListExpr, ImplicitCastExpr, CXXBindTemporaryExpr>(AE);
3182+ }
3183+
3184+ static bool isRvalueByValueRecordWithSmartPtr (const Expr *AE) {
3185+ if (!isRvalueByValueRecord (AE))
3186+ return false ;
3187+
3188+ const auto *CRD = AE->getType ()->getAsCXXRecordDecl ();
3189+ return CRD && hasSmartPtrField (CRD);
3190+ }
3191+
3192+ static ProgramStateRef escapeAllAllocatedSymbols (ProgramStateRef State) {
3193+ RegionStateTy RS = State->get <RegionState>();
3194+ ProgramStateRef NewState = State;
3195+ for (auto [Sym, RefSt] : RS) {
3196+ if (RefSt.isAllocated () || RefSt.isAllocatedOfSizeZero ()) {
3197+ NewState = NewState->set <RegionState>(Sym, RefState::getEscaped (&RefSt));
3198+ }
3199+ }
3200+ return NewState;
3201+ }
3202+
31653203static void collectDirectSmartOwningPtrFieldRegions (
31663204 const MemRegion *Base, QualType RecQT, CheckerContext &C,
31673205 SmallVectorImpl<const MemRegion *> &Out) {
@@ -3171,13 +3209,29 @@ static void collectDirectSmartOwningPtrFieldRegions(
31713209 if (!CRD)
31723210 return ;
31733211
3212+ // Collect direct fields
31743213 for (const FieldDecl *FD : CRD->fields ()) {
31753214 if (!isSmartOwningPtrType (FD->getType ()))
31763215 continue ;
31773216 SVal L = C.getState ()->getLValue (FD, loc::MemRegionVal (Base));
31783217 if (const MemRegion *FR = L.getAsRegion ())
31793218 Out.push_back (FR);
31803219 }
3220+
3221+ // Collect fields from base classes
3222+ for (const CXXBaseSpecifier &BaseSpec : CRD->bases ()) {
3223+ if (const CXXRecordDecl *BaseDecl =
3224+ BaseSpec.getType ()->getAsCXXRecordDecl ()) {
3225+ // Get the base class region
3226+ SVal BaseL = C.getState ()->getLValue (BaseDecl, Base->getAs <SubRegion>(),
3227+ BaseSpec.isVirtual ());
3228+ if (const MemRegion *BaseRegion = BaseL.getAsRegion ()) {
3229+ // Recursively collect fields from this base class
3230+ collectDirectSmartOwningPtrFieldRegions (BaseRegion, BaseSpec.getType (),
3231+ C, Out);
3232+ }
3233+ }
3234+ }
31813235}
31823236
31833237void MallocChecker::checkPostCall (const CallEvent &Call,
@@ -3195,38 +3249,7 @@ void MallocChecker::checkPostCall(const CallEvent &Call,
31953249 continue ;
31963250 AE = AE->IgnoreParenImpCasts ();
31973251
3198- QualType T = AE->getType ();
3199-
3200- // **Relaxation 1**: accept *any rvalue* by-value record (not only strict
3201- // PRVALUE).
3202- if (AE->isGLValue ())
3203- continue ;
3204-
3205- // By-value record only (no refs).
3206- if (!T->isRecordType () || T->isReferenceType ())
3207- continue ;
3208-
3209- // **Relaxation 2**: accept common temp/construct forms but don't overfit.
3210- const bool LooksLikeTemp =
3211- isa<CXXTemporaryObjectExpr>(AE) || isa<MaterializeTemporaryExpr>(AE) ||
3212- isa<CXXConstructExpr>(AE) || isa<InitListExpr>(AE) ||
3213- isa<ImplicitCastExpr>(AE) || // handle common rvalue materializations
3214- isa<CXXBindTemporaryExpr>(AE); // handle CXXBindTemporaryExpr
3215- if (!LooksLikeTemp)
3216- continue ;
3217-
3218- // Require at least one direct smart owning pointer field by type.
3219- const auto *CRD = T->getAsCXXRecordDecl ();
3220- if (!CRD)
3221- continue ;
3222- bool HasSmartPtrField = false ;
3223- for (const FieldDecl *FD : CRD->fields ()) {
3224- if (isSmartOwningPtrType (FD->getType ())) {
3225- HasSmartPtrField = true ;
3226- break ;
3227- }
3228- }
3229- if (!HasSmartPtrField)
3252+ if (!isRvalueByValueRecordWithSmartPtr (AE))
32303253 continue ;
32313254
32323255 // Find a region for the argument.
@@ -3237,32 +3260,26 @@ void MallocChecker::checkPostCall(const CallEvent &Call,
32373260
32383261 const MemRegion *Base = RCall ? RCall : RExpr;
32393262 if (!Base) {
3240- // Fallback: if we have a by-value record with unique_ptr fields but no
3263+ // Fallback: if we have a by-value record with smart pointer fields but no
32413264 // region, mark all allocated symbols as escaped
32423265 ProgramStateRef State = C.getState ();
3243- RegionStateTy RS = State->get <RegionState>();
3244- ProgramStateRef NewState = State;
3245- for (auto [Sym, RefSt] : RS) {
3246- if (RefSt.isAllocated () || RefSt.isAllocatedOfSizeZero ()) {
3247- NewState =
3248- NewState->set <RegionState>(Sym, RefState::getEscaped (&RefSt));
3249- }
3250- }
3266+ ProgramStateRef NewState = escapeAllAllocatedSymbols (State);
32513267 if (NewState != State)
32523268 C.addTransition (NewState);
32533269 continue ;
32543270 }
32553271
32563272 // Push direct smart owning pointer field regions only (precise root set).
3257- collectDirectSmartOwningPtrFieldRegions (Base, T, C, SmartPtrFieldRoots);
3273+ collectDirectSmartOwningPtrFieldRegions (Base, AE->getType (), C,
3274+ SmartPtrFieldRoots);
32583275 }
32593276
32603277 // Escape only from those field roots; do nothing if empty.
32613278 if (!SmartPtrFieldRoots.empty ()) {
32623279 ProgramStateRef State = C.getState ();
3263- auto Scan =
3264- State-> scanReachableSymbols < EscapeTrackedCallback>(SmartPtrFieldRoots);
3265- ProgramStateRef NewState = Scan. getState ( );
3280+ ProgramStateRef NewState =
3281+ EscapeTrackedCallback::EscapeTrackedRegionsReachableFrom (
3282+ SmartPtrFieldRoots, State );
32663283 if (NewState != State) {
32673284 C.addTransition (NewState);
32683285 } else {
@@ -3276,44 +3293,15 @@ void MallocChecker::checkPostCall(const CallEvent &Call,
32763293 continue ;
32773294 AE = AE->IgnoreParenImpCasts ();
32783295
3279- if (AE->isGLValue ())
3280- continue ;
3281- QualType T = AE->getType ();
3282- if (!T->isRecordType () || T->isReferenceType ())
3283- continue ;
3284-
3285- const bool LooksLikeTemp =
3286- isa<CXXTemporaryObjectExpr>(AE) ||
3287- isa<MaterializeTemporaryExpr>(AE) || isa<CXXConstructExpr>(AE) ||
3288- isa<InitListExpr>(AE) || isa<ImplicitCastExpr>(AE) ||
3289- isa<CXXBindTemporaryExpr>(AE);
3290- if (!LooksLikeTemp)
3291- continue ;
3292-
3293- // Check if this record type has smart pointer fields
3294- const auto *CRD = T->getAsCXXRecordDecl ();
3295- if (CRD) {
3296- for (const FieldDecl *FD : CRD->fields ()) {
3297- if (isSmartOwningPtrType (FD->getType ())) {
3298- hasByValueRecordWithSmartPtr = true ;
3299- break ;
3300- }
3301- }
3302- }
3303- if (hasByValueRecordWithSmartPtr)
3296+ if (isRvalueByValueRecordWithSmartPtr (AE)) {
3297+ hasByValueRecordWithSmartPtr = true ;
33043298 break ;
3299+ }
33053300 }
33063301
33073302 if (hasByValueRecordWithSmartPtr) {
33083303 ProgramStateRef State = C.getState ();
3309- RegionStateTy RS = State->get <RegionState>();
3310- ProgramStateRef NewState = State;
3311- for (auto [Sym, RefSt] : RS) {
3312- if (RefSt.isAllocated () || RefSt.isAllocatedOfSizeZero ()) {
3313- NewState =
3314- NewState->set <RegionState>(Sym, RefState::getEscaped (&RefSt));
3315- }
3316- }
3304+ ProgramStateRef NewState = escapeAllAllocatedSymbols (State);
33173305 if (NewState != State)
33183306 C.addTransition (NewState);
33193307 }
@@ -3439,7 +3427,6 @@ void MallocChecker::checkEscapeOnReturn(const ReturnStmt *S,
34393427 if (!Sym)
34403428 // If we are returning a field of the allocated struct or an array element,
34413429 // the callee could still free the memory.
3442- // TODO: This logic should be a part of generic symbol escape callback.
34433430 if (const MemRegion *MR = RetVal.getAsRegion ())
34443431 if (isa<FieldRegion, ElementRegion>(MR))
34453432 if (const SymbolicRegion *BMR =
0 commit comments