52
52
#include " clang/AST/DeclTemplate.h"
53
53
#include " clang/AST/Expr.h"
54
54
#include " clang/AST/ExprCXX.h"
55
- #include " clang/AST/TemplateBase.h"
56
- #include " clang/AST/Type.h"
57
55
58
56
#include " clang/AST/ParentMap.h"
59
57
#include " clang/ASTMatchers/ASTMatchFinder.h"
@@ -1113,17 +1111,23 @@ class StopTrackingCallback final : public SymbolVisitor {
1113
1111
// / The visitor traverses reachable symbols from a given set of memory regions
1114
1112
// / (typically smart pointer field regions) and marks any allocated symbols as
1115
1113
// / escaped. Escaped symbols are not reported as leaks by checkDeadSymbols.
1116
- // /
1117
- // / Usage:
1118
- // / auto Scan =
1119
- // / State->scanReachableSymbols<EscapeTrackedCallback>(RootRegions);
1120
- // / ProgramStateRef NewState = Scan.getState();
1121
- // / if (NewState != State) C.addTransition(NewState);
1122
1114
class EscapeTrackedCallback final : public SymbolVisitor {
1123
1115
ProgramStateRef State;
1124
1116
1125
- public:
1126
1117
explicit EscapeTrackedCallback (ProgramStateRef S) : State(std::move(S)) {}
1118
+
1119
+ public:
1120
+ // / Escape tracked regions reachable from the given roots.
1121
+ static ProgramStateRef
1122
+ EscapeTrackedRegionsReachableFrom (ArrayRef<const MemRegion *> Roots,
1123
+ ProgramStateRef State) {
1124
+ EscapeTrackedCallback Visitor (State);
1125
+ for (const MemRegion *R : Roots) {
1126
+ State->scanReachableSymbols (loc::MemRegionVal (R), Visitor);
1127
+ }
1128
+ return Visitor.getState ();
1129
+ }
1130
+
1127
1131
ProgramStateRef getState () const { return State; }
1128
1132
1129
1133
bool VisitSymbol (SymbolRef Sym) override {
@@ -3107,24 +3111,13 @@ void MallocChecker::checkDeadSymbols(SymbolReaper &SymReaper,
3107
3111
C.addTransition (state->set <RegionState>(RS), N);
3108
3112
}
3109
3113
3110
- static QualType canonicalStrip (QualType QT) {
3111
- return QT.getCanonicalType ().getUnqualifiedType ();
3112
- }
3113
-
3114
- static bool isInStdNamespace (const DeclContext *DC) {
3115
- while (DC) {
3116
- if (const auto *NS = dyn_cast<NamespaceDecl>(DC))
3117
- if (NS->isStdNamespace ())
3118
- return true ;
3119
- DC = DC->getParent ();
3120
- }
3121
- return false ;
3122
- }
3114
+ // Use isWithinStdNamespace from CheckerHelpers.h instead of custom
3115
+ // implementation
3123
3116
3124
3117
// Allowlist of owning smart pointers we want to recognize.
3125
3118
// Start with unique_ptr and shared_ptr. (intentionally exclude weak_ptr)
3126
3119
static bool isSmartOwningPtrType (QualType QT) {
3127
- QT = canonicalStrip (QT );
3120
+ QT = QT-> getCanonicalTypeUnqualified ( );
3128
3121
3129
3122
// First try TemplateSpecializationType (for std smart pointers)
3130
3123
const auto *TST = QT->getAs <TemplateSpecializationType>();
@@ -3138,30 +3131,75 @@ static bool isSmartOwningPtrType(QualType QT) {
3138
3131
return false ;
3139
3132
3140
3133
// Check if it's in std namespace
3141
- const DeclContext *DC = ND->getDeclContext ();
3142
- if (!isInStdNamespace (DC))
3134
+ if (!isWithinStdNamespace (ND))
3143
3135
return false ;
3144
3136
3145
3137
StringRef Name = ND->getName ();
3146
3138
return Name == " unique_ptr" || Name == " shared_ptr" ;
3147
3139
}
3148
3140
3149
3141
// Also try RecordType (for custom smart pointer implementations)
3150
- const auto *RT = QT->getAs <RecordType>();
3151
- if (RT) {
3152
- const auto *RD = RT->getDecl ();
3153
- if (RD) {
3154
- StringRef Name = RD->getName ();
3155
- if (Name == " unique_ptr" || Name == " shared_ptr" ) {
3156
- // Accept any custom unique_ptr or shared_ptr implementation
3157
- return true ;
3158
- }
3142
+ const auto *RD = QT->getAsCXXRecordDecl ();
3143
+ if (RD) {
3144
+ StringRef Name = RD->getName ();
3145
+ if (Name == " unique_ptr" || Name == " shared_ptr" ) {
3146
+ // Accept any custom unique_ptr or shared_ptr implementation
3147
+ return true ;
3159
3148
}
3160
3149
}
3161
3150
3162
3151
return false ;
3163
3152
}
3164
3153
3154
+ static bool hasSmartPtrField (const CXXRecordDecl *CRD) {
3155
+ // Check direct fields
3156
+ if (llvm::any_of (CRD->fields (), [](const FieldDecl *FD) {
3157
+ return isSmartOwningPtrType (FD->getType ());
3158
+ }))
3159
+ return true ;
3160
+
3161
+ // Check fields from base classes
3162
+ for (const CXXBaseSpecifier &Base : CRD->bases ()) {
3163
+ if (const CXXRecordDecl *BaseDecl = Base.getType ()->getAsCXXRecordDecl ()) {
3164
+ if (hasSmartPtrField (BaseDecl))
3165
+ return true ;
3166
+ }
3167
+ }
3168
+ return false ;
3169
+ }
3170
+
3171
+ static bool isRvalueByValueRecord (const Expr *AE) {
3172
+ if (AE->isGLValue ())
3173
+ return false ;
3174
+
3175
+ QualType T = AE->getType ();
3176
+ if (!T->isRecordType () || T->isReferenceType ())
3177
+ return false ;
3178
+
3179
+ // Accept common temp/construct forms but don't overfit.
3180
+ return isa<CXXTemporaryObjectExpr, MaterializeTemporaryExpr, CXXConstructExpr,
3181
+ InitListExpr, ImplicitCastExpr, CXXBindTemporaryExpr>(AE);
3182
+ }
3183
+
3184
+ static bool isRvalueByValueRecordWithSmartPtr (const Expr *AE) {
3185
+ if (!isRvalueByValueRecord (AE))
3186
+ return false ;
3187
+
3188
+ const auto *CRD = AE->getType ()->getAsCXXRecordDecl ();
3189
+ return CRD && hasSmartPtrField (CRD);
3190
+ }
3191
+
3192
+ static ProgramStateRef escapeAllAllocatedSymbols (ProgramStateRef State) {
3193
+ RegionStateTy RS = State->get <RegionState>();
3194
+ ProgramStateRef NewState = State;
3195
+ for (auto [Sym, RefSt] : RS) {
3196
+ if (RefSt.isAllocated () || RefSt.isAllocatedOfSizeZero ()) {
3197
+ NewState = NewState->set <RegionState>(Sym, RefState::getEscaped (&RefSt));
3198
+ }
3199
+ }
3200
+ return NewState;
3201
+ }
3202
+
3165
3203
static void collectDirectSmartOwningPtrFieldRegions (
3166
3204
const MemRegion *Base, QualType RecQT, CheckerContext &C,
3167
3205
SmallVectorImpl<const MemRegion *> &Out) {
@@ -3171,13 +3209,29 @@ static void collectDirectSmartOwningPtrFieldRegions(
3171
3209
if (!CRD)
3172
3210
return ;
3173
3211
3212
+ // Collect direct fields
3174
3213
for (const FieldDecl *FD : CRD->fields ()) {
3175
3214
if (!isSmartOwningPtrType (FD->getType ()))
3176
3215
continue ;
3177
3216
SVal L = C.getState ()->getLValue (FD, loc::MemRegionVal (Base));
3178
3217
if (const MemRegion *FR = L.getAsRegion ())
3179
3218
Out.push_back (FR);
3180
3219
}
3220
+
3221
+ // Collect fields from base classes
3222
+ for (const CXXBaseSpecifier &BaseSpec : CRD->bases ()) {
3223
+ if (const CXXRecordDecl *BaseDecl =
3224
+ BaseSpec.getType ()->getAsCXXRecordDecl ()) {
3225
+ // Get the base class region
3226
+ SVal BaseL = C.getState ()->getLValue (BaseDecl, Base->getAs <SubRegion>(),
3227
+ BaseSpec.isVirtual ());
3228
+ if (const MemRegion *BaseRegion = BaseL.getAsRegion ()) {
3229
+ // Recursively collect fields from this base class
3230
+ collectDirectSmartOwningPtrFieldRegions (BaseRegion, BaseSpec.getType (),
3231
+ C, Out);
3232
+ }
3233
+ }
3234
+ }
3181
3235
}
3182
3236
3183
3237
void MallocChecker::checkPostCall (const CallEvent &Call,
@@ -3195,38 +3249,7 @@ void MallocChecker::checkPostCall(const CallEvent &Call,
3195
3249
continue ;
3196
3250
AE = AE->IgnoreParenImpCasts ();
3197
3251
3198
- QualType T = AE->getType ();
3199
-
3200
- // **Relaxation 1**: accept *any rvalue* by-value record (not only strict
3201
- // PRVALUE).
3202
- if (AE->isGLValue ())
3203
- continue ;
3204
-
3205
- // By-value record only (no refs).
3206
- if (!T->isRecordType () || T->isReferenceType ())
3207
- continue ;
3208
-
3209
- // **Relaxation 2**: accept common temp/construct forms but don't overfit.
3210
- const bool LooksLikeTemp =
3211
- isa<CXXTemporaryObjectExpr>(AE) || isa<MaterializeTemporaryExpr>(AE) ||
3212
- isa<CXXConstructExpr>(AE) || isa<InitListExpr>(AE) ||
3213
- isa<ImplicitCastExpr>(AE) || // handle common rvalue materializations
3214
- isa<CXXBindTemporaryExpr>(AE); // handle CXXBindTemporaryExpr
3215
- if (!LooksLikeTemp)
3216
- continue ;
3217
-
3218
- // Require at least one direct smart owning pointer field by type.
3219
- const auto *CRD = T->getAsCXXRecordDecl ();
3220
- if (!CRD)
3221
- continue ;
3222
- bool HasSmartPtrField = false ;
3223
- for (const FieldDecl *FD : CRD->fields ()) {
3224
- if (isSmartOwningPtrType (FD->getType ())) {
3225
- HasSmartPtrField = true ;
3226
- break ;
3227
- }
3228
- }
3229
- if (!HasSmartPtrField)
3252
+ if (!isRvalueByValueRecordWithSmartPtr (AE))
3230
3253
continue ;
3231
3254
3232
3255
// Find a region for the argument.
@@ -3237,32 +3260,26 @@ void MallocChecker::checkPostCall(const CallEvent &Call,
3237
3260
3238
3261
const MemRegion *Base = RCall ? RCall : RExpr;
3239
3262
if (!Base) {
3240
- // Fallback: if we have a by-value record with unique_ptr fields but no
3263
+ // Fallback: if we have a by-value record with smart pointer fields but no
3241
3264
// region, mark all allocated symbols as escaped
3242
3265
ProgramStateRef State = C.getState ();
3243
- RegionStateTy RS = State->get <RegionState>();
3244
- ProgramStateRef NewState = State;
3245
- for (auto [Sym, RefSt] : RS) {
3246
- if (RefSt.isAllocated () || RefSt.isAllocatedOfSizeZero ()) {
3247
- NewState =
3248
- NewState->set <RegionState>(Sym, RefState::getEscaped (&RefSt));
3249
- }
3250
- }
3266
+ ProgramStateRef NewState = escapeAllAllocatedSymbols (State);
3251
3267
if (NewState != State)
3252
3268
C.addTransition (NewState);
3253
3269
continue ;
3254
3270
}
3255
3271
3256
3272
// Push direct smart owning pointer field regions only (precise root set).
3257
- collectDirectSmartOwningPtrFieldRegions (Base, T, C, SmartPtrFieldRoots);
3273
+ collectDirectSmartOwningPtrFieldRegions (Base, AE->getType (), C,
3274
+ SmartPtrFieldRoots);
3258
3275
}
3259
3276
3260
3277
// Escape only from those field roots; do nothing if empty.
3261
3278
if (!SmartPtrFieldRoots.empty ()) {
3262
3279
ProgramStateRef State = C.getState ();
3263
- auto Scan =
3264
- State-> scanReachableSymbols < EscapeTrackedCallback>(SmartPtrFieldRoots);
3265
- ProgramStateRef NewState = Scan. getState ( );
3280
+ ProgramStateRef NewState =
3281
+ EscapeTrackedCallback::EscapeTrackedRegionsReachableFrom (
3282
+ SmartPtrFieldRoots, State );
3266
3283
if (NewState != State) {
3267
3284
C.addTransition (NewState);
3268
3285
} else {
@@ -3276,44 +3293,15 @@ void MallocChecker::checkPostCall(const CallEvent &Call,
3276
3293
continue ;
3277
3294
AE = AE->IgnoreParenImpCasts ();
3278
3295
3279
- if (AE->isGLValue ())
3280
- continue ;
3281
- QualType T = AE->getType ();
3282
- if (!T->isRecordType () || T->isReferenceType ())
3283
- continue ;
3284
-
3285
- const bool LooksLikeTemp =
3286
- isa<CXXTemporaryObjectExpr>(AE) ||
3287
- isa<MaterializeTemporaryExpr>(AE) || isa<CXXConstructExpr>(AE) ||
3288
- isa<InitListExpr>(AE) || isa<ImplicitCastExpr>(AE) ||
3289
- isa<CXXBindTemporaryExpr>(AE);
3290
- if (!LooksLikeTemp)
3291
- continue ;
3292
-
3293
- // Check if this record type has smart pointer fields
3294
- const auto *CRD = T->getAsCXXRecordDecl ();
3295
- if (CRD) {
3296
- for (const FieldDecl *FD : CRD->fields ()) {
3297
- if (isSmartOwningPtrType (FD->getType ())) {
3298
- hasByValueRecordWithSmartPtr = true ;
3299
- break ;
3300
- }
3301
- }
3302
- }
3303
- if (hasByValueRecordWithSmartPtr)
3296
+ if (isRvalueByValueRecordWithSmartPtr (AE)) {
3297
+ hasByValueRecordWithSmartPtr = true ;
3304
3298
break ;
3299
+ }
3305
3300
}
3306
3301
3307
3302
if (hasByValueRecordWithSmartPtr) {
3308
3303
ProgramStateRef State = C.getState ();
3309
- RegionStateTy RS = State->get <RegionState>();
3310
- ProgramStateRef NewState = State;
3311
- for (auto [Sym, RefSt] : RS) {
3312
- if (RefSt.isAllocated () || RefSt.isAllocatedOfSizeZero ()) {
3313
- NewState =
3314
- NewState->set <RegionState>(Sym, RefState::getEscaped (&RefSt));
3315
- }
3316
- }
3304
+ ProgramStateRef NewState = escapeAllAllocatedSymbols (State);
3317
3305
if (NewState != State)
3318
3306
C.addTransition (NewState);
3319
3307
}
@@ -3439,7 +3427,6 @@ void MallocChecker::checkEscapeOnReturn(const ReturnStmt *S,
3439
3427
if (!Sym)
3440
3428
// If we are returning a field of the allocated struct or an array element,
3441
3429
// the callee could still free the memory.
3442
- // TODO: This logic should be a part of generic symbol escape callback.
3443
3430
if (const MemRegion *MR = RetVal.getAsRegion ())
3444
3431
if (isa<FieldRegion, ElementRegion>(MR))
3445
3432
if (const SymbolicRegion *BMR =
0 commit comments