Skip to content

Commit 11b7358

Browse files
committed
Fix multi-level pointer narrowing
1 parent ac90a85 commit 11b7358

File tree

7 files changed

+205
-42
lines changed

7 files changed

+205
-42
lines changed

clang/include/clang/Sema/Sema.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,8 @@ class Sema final : public SemaBase {
11841184
/// strict-nullability: Narrow a variable's type to non-null in the current scope.
11851185
void NarrowVariableToNonNull(const VarDecl *VD);
11861186

1187+
void NarrowVariablePointeeToNonNull(const VarDecl *VD);
1188+
11871189
/// strict-nullability: Get the narrowed type for a variable, if any.
11881190
/// Returns the narrowed type if the variable has been narrowed in the
11891191
/// current scope, otherwise returns an empty QualType.
@@ -1206,6 +1208,9 @@ class Sema final : public SemaBase {
12061208
void CollectAndCheckedVariables(Expr *Cond,
12071209
SmallVectorImpl<const VarDecl*> &Vars);
12081210

1211+
void CollectAndCheckedDereferences(Expr *Cond,
1212+
SmallVectorImpl<const VarDecl*> &Vars);
1213+
12091214
/// strict-nullability: Check if a condition expression contains any function calls.
12101215
/// This is used to determine if it's safe to apply narrowing from a condition,
12111216
/// since function calls can invalidate pointers before we apply narrowing.

clang/lib/Parse/ParseStmt.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1557,7 +1557,13 @@ StmtResult Parser::ParseIfStatement(SourceLocation *TrailingElseLoc) {
15571557
for (const VarDecl *VD : Actions.AndExprCheckedVars) {
15581558
Actions.NarrowVariableToNonNull(VD);
15591559
}
1560-
Actions.AndExprCheckedVars.clear(); // Consume the list
1560+
Actions.AndExprCheckedVars.clear();
1561+
1562+
SmallVector<const VarDecl*, 8> DerefCheckedVars;
1563+
Actions.CollectAndCheckedDereferences(Cond.get().second, DerefCheckedVars);
1564+
for (const VarDecl *VD : DerefCheckedVars) {
1565+
Actions.NarrowVariablePointeeToNonNull(VD);
1566+
}
15611567

15621568
// Also narrow the traditional null-checked variable (but not if condition has calls)
15631569
if (CheckedVar && !IsNegatedCheck && !ConditionHasCalls) {
@@ -1902,6 +1908,12 @@ StmtResult Parser::ParseWhileStatement(SourceLocation *TrailingElseLoc,
19021908
Actions.NarrowVariableToNonNull(VD);
19031909
}
19041910

1911+
SmallVector<const VarDecl*, 8> DerefCheckedVars;
1912+
Actions.CollectAndCheckedDereferences(Cond.get().second, DerefCheckedVars);
1913+
for (const VarDecl *VD : DerefCheckedVars) {
1914+
Actions.NarrowVariablePointeeToNonNull(VD);
1915+
}
1916+
19051917
// Handle dereferences in condition: while (*p == 'x')
19061918
SmallVector<const VarDecl*, 4> DereferencedVars;
19071919
Actions.CollectDereferencedVariables(Cond.get().second, DereferencedVars);
@@ -2409,6 +2421,12 @@ StmtResult Parser::ParseForStatement(SourceLocation *TrailingElseLoc,
24092421
Actions.NarrowVariableToNonNull(VD);
24102422
}
24112423

2424+
SmallVector<const VarDecl*, 8> DerefCheckedVars;
2425+
Actions.CollectAndCheckedDereferences(SecondPart.get().second, DerefCheckedVars);
2426+
for (const VarDecl *VD : DerefCheckedVars) {
2427+
Actions.NarrowVariablePointeeToNonNull(VD);
2428+
}
2429+
24122430
// Handle dereferences in condition: for (; *p == 'x'; )
24132431
SmallVector<const VarDecl*, 4> DereferencedVars;
24142432
Actions.CollectDereferencedVariables(SecondPart.get().second, DereferencedVars);

clang/lib/Sema/Sema.cpp

Lines changed: 101 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -729,26 +729,71 @@ void Sema::NarrowVariableToNonNull(const VarDecl *VD) {
729729
return;
730730

731731
QualType OrigType = VD->getType();
732-
std::optional<NullabilityKind> Nullability = OrigType->getNullability();
733732

734-
// Only narrow if the type is nullable
735-
if (!Nullability || *Nullability != NullabilityKind::Nullable)
733+
QualType CurrentType = GetNarrowedType(VD);
734+
if (CurrentType.isNull())
735+
CurrentType = OrigType;
736+
737+
std::optional<NullabilityKind> Nullability = CurrentType->getNullability();
738+
739+
if (Nullability && *Nullability == NullabilityKind::NonNull)
736740
return;
737741

738-
// Strip the existing nullability attribute first
739-
QualType BaseType = OrigType;
742+
QualType BaseType = CurrentType;
740743
(void)AttributedType::stripOuterNullability(BaseType);
741744

742-
// Create a non-null version of the type
743745
QualType NarrowedType = Context.getAttributedType(
744746
NullabilityKind::NonNull,
745747
BaseType,
746748
BaseType);
747749

748-
// Store it in the current scope
749750
NullabilityNarrowingScopes.back()[VD] = NarrowedType;
750751
}
751752

753+
void Sema::NarrowVariablePointeeToNonNull(const VarDecl *VD) {
754+
if (!getLangOpts().StrictNullability)
755+
return;
756+
if (!VD || NullabilityNarrowingScopes.empty())
757+
return;
758+
759+
QualType CurrentType = GetNarrowedType(VD);
760+
if (CurrentType.isNull())
761+
CurrentType = VD->getType();
762+
763+
const PointerType *PtrType = CurrentType->getAs<PointerType>();
764+
if (!PtrType)
765+
return;
766+
767+
QualType PointeeType = PtrType->getPointeeType();
768+
std::optional<NullabilityKind> PointeeNullability = PointeeType->getNullability();
769+
770+
if (PointeeNullability && *PointeeNullability == NullabilityKind::NonNull)
771+
return;
772+
773+
QualType PointeeBase = PointeeType;
774+
(void)AttributedType::stripOuterNullability(PointeeBase);
775+
776+
QualType NarrowedPointee = Context.getAttributedType(
777+
NullabilityKind::NonNull,
778+
PointeeBase,
779+
PointeeBase);
780+
781+
std::optional<NullabilityKind> OuterNullability = CurrentType->getNullability();
782+
QualType OuterBase = CurrentType;
783+
(void)AttributedType::stripOuterNullability(OuterBase);
784+
785+
QualType NewPointerType = Context.getPointerType(NarrowedPointee);
786+
787+
if (OuterNullability) {
788+
NewPointerType = Context.getAttributedType(
789+
*OuterNullability,
790+
NewPointerType,
791+
NewPointerType);
792+
}
793+
794+
NullabilityNarrowingScopes.back()[VD] = NewPointerType;
795+
}
796+
752797
QualType Sema::GetNarrowedType(const VarDecl *VD) const {
753798
if (!getLangOpts().StrictNullability)
754799
return QualType();
@@ -826,10 +871,23 @@ const VarDecl* Sema::AnalyzeConditionForNullCheck(Expr *Cond, bool &IsNegated) {
826871
}
827872
}
828873

874+
// Handle: *ptr (dereferenced pointer)
875+
if (auto *UO = dyn_cast<UnaryOperator>(E)) {
876+
if (UO->getOpcode() == UO_Deref) {
877+
Expr *SubExpr = UO->getSubExpr()->IgnoreParenImpCasts();
878+
if (auto *DRE = dyn_cast<DeclRefExpr>(SubExpr)) {
879+
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
880+
if (VD->getType()->isPointerType()) {
881+
return VD;
882+
}
883+
}
884+
}
885+
}
886+
}
887+
829888
// Handle: ptr (implicit boolean conversion)
830889
if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
831890
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
832-
// Check if this is a pointer type
833891
if (VD->getType()->isPointerType()) {
834892
return VD;
835893
}
@@ -882,23 +940,55 @@ void Sema::CollectAndCheckedVariables(Expr *Cond,
882940
// Check if this is an AND expression
883941
if (auto *BO = dyn_cast<BinaryOperator>(E)) {
884942
if (BO->getOpcode() == BO_LAnd) {
885-
// Recursively collect from both sides of the AND
886943
CollectAndCheckedVariables(BO->getLHS(), Vars);
887944
CollectAndCheckedVariables(BO->getRHS(), Vars);
888945
return;
889946
}
890947
}
891948

892-
// Otherwise, check if this is a simple non-negated null check
893949
bool IsNegated = false;
894950
const VarDecl *VD = AnalyzeConditionForNullCheck(Cond, IsNegated);
895951

896-
// For AND patterns like "p && q", we want non-negated checks
897952
if (VD && !IsNegated) {
898953
Vars.push_back(VD);
899954
}
900955
}
901956

957+
void Sema::CollectAndCheckedDereferences(Expr *Cond,
958+
SmallVectorImpl<const VarDecl*> &Vars) {
959+
if (!Cond)
960+
return;
961+
962+
Expr *E = Cond->IgnoreParenImpCasts();
963+
964+
if (auto *BO = dyn_cast<BinaryOperator>(E)) {
965+
if (BO->getOpcode() == BO_LAnd) {
966+
CollectAndCheckedDereferences(BO->getLHS(), Vars);
967+
CollectAndCheckedDereferences(BO->getRHS(), Vars);
968+
return;
969+
}
970+
}
971+
972+
if (auto *UO = dyn_cast<UnaryOperator>(E)) {
973+
if (UO->getOpcode() == UO_LNot) {
974+
E = UO->getSubExpr()->IgnoreParenImpCasts();
975+
}
976+
}
977+
978+
if (auto *UO = dyn_cast<UnaryOperator>(E)) {
979+
if (UO->getOpcode() == UO_Deref) {
980+
Expr *SubExpr = UO->getSubExpr()->IgnoreParenImpCasts();
981+
if (auto *DRE = dyn_cast<DeclRefExpr>(SubExpr)) {
982+
if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
983+
if (VD->getType()->isPointerType()) {
984+
Vars.push_back(VD);
985+
}
986+
}
987+
}
988+
}
989+
}
990+
}
991+
902992
// strict-nullability: Check if an expression contains any function calls that could have side effects
903993
// This is used to determine if it's safe to apply narrowing from a condition
904994
static bool ContainsSideEffectingCall(Expr *E) {

clang/lib/Sema/SemaExpr.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14900,35 +14900,41 @@ static QualType CheckIndirectionOperand(Sema &S, Expr *Op, ExprValueKind &VK,
1490014900
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1490114901
QualType NarrowedType = S.GetNarrowedType(VD);
1490214902
if (!NarrowedType.isNull()) {
14903-
// Use the narrowed type for nullability checking
1490414903
CheckType = NarrowedType;
1490514904
}
1490614905
}
1490714906
}
14908-
// strict-nullability: Also check if this is an increment/decrement of a narrowed variable
14909-
// For *p++, Op is the UnaryOperator(++), and its sub-expression is the variable
1491014907
else if (const auto *UO = dyn_cast<UnaryOperator>(Op->IgnoreParenImpCasts())) {
14911-
if (UO->getOpcode() == UO_PostInc || UO->getOpcode() == UO_PreInc ||
14908+
if (UO->getOpcode() == UO_Deref) {
14909+
Expr *DerefSubExpr = UO->getSubExpr()->IgnoreParenImpCasts();
14910+
if (const auto *DRE = dyn_cast<DeclRefExpr>(DerefSubExpr)) {
14911+
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
14912+
QualType NarrowedType = S.GetNarrowedType(VD);
14913+
if (!NarrowedType.isNull()) {
14914+
if (const PointerType *PT = NarrowedType->getAs<PointerType>()) {
14915+
CheckType = PT->getPointeeType();
14916+
}
14917+
}
14918+
}
14919+
}
14920+
}
14921+
else if (UO->getOpcode() == UO_PostInc || UO->getOpcode() == UO_PreInc ||
1491214922
UO->getOpcode() == UO_PostDec || UO->getOpcode() == UO_PreDec) {
14913-
// Get the variable being incremented/decremented
1491414923
if (const auto *DRE = dyn_cast<DeclRefExpr>(UO->getSubExpr()->IgnoreParenImpCasts())) {
1491514924
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
1491614925
QualType NarrowedType = S.GetNarrowedType(VD);
1491714926
if (!NarrowedType.isNull()) {
14918-
// The result of p++ is the old value of p, which had the narrowed type
1491914927
CheckType = NarrowedType;
1492014928
}
1492114929
}
1492214930
}
1492314931
}
1492414932
}
1492514933

14926-
if (auto Nullability = CheckType->getNullability()) {
14927-
if (*Nullability == NullabilityKind::Nullable) {
14928-
// strict-nullability: Warn about dereferencing nullable pointers.
14929-
// Dereferencing does NOT perform a null-check - it will crash if null!
14930-
S.Diag(OpLoc, diag::warn_strict_nullability_dereference) << OpTy;
14931-
}
14934+
auto Nullability = CheckType->getNullability();
14935+
if (!Nullability || *Nullability == NullabilityKind::Nullable ||
14936+
*Nullability == NullabilityKind::Unspecified) {
14937+
S.Diag(OpLoc, diag::warn_strict_nullability_dereference) << OpTy;
1493214938
}
1493314939
}
1493414940

clang/test/Sema/strict-nullability.c

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,6 @@ void test_chained_deref(int** pp) {
404404
**pp = 42; // OK - both levels narrowed
405405
}
406406
}
407-
408-
// Test: Triple pointers
409-
void test_triple_pointers(int*** ppp) {
410-
if (ppp && *ppp && **ppp) {
411-
***ppp = 42; // OK - all three levels narrowed
412-
}
413-
}
414-
415407
// Test: Declared nonnull multi-level
416408
void test_nonnull_outer_ptr(int* * _Nonnull pp) {
417409
*pp = 0; // OK - pp is nonnull by declaration
@@ -627,3 +619,31 @@ void test_volatile_with_check(void) {
627619
}
628620
}
629621

622+
void test_double_deref_no_check(int** pp) {
623+
**pp = 42; // expected-warning{{dereferencing nullable pointer of type 'int **'}}
624+
// expected-warning@-1{{dereferencing nullable pointer of type 'int *'}}
625+
}
626+
627+
void test_double_deref_both_checked(int** pp) {
628+
if (pp && *pp) {
629+
**pp = 42;
630+
}
631+
}
632+
633+
void test_double_deref_only_outer_checked(int** pp) {
634+
if (pp) {
635+
**pp = 42; // expected-warning{{dereferencing nullable pointer of type 'int *'}}
636+
}
637+
}
638+
639+
void test_double_deref_nonnull_inner(int* _Nonnull * pp) {
640+
if (pp) {
641+
**pp = 42;
642+
}
643+
}
644+
645+
void test_double_deref_nonnull(int* _Nonnull *_Nonnull pp) {
646+
**pp = 42;
647+
}
648+
649+
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1-
// Multi-level pointers require careful handling
2-
void deref_twice(int** pp) {
3-
**pp = 42; // warning - *pp might be null!
1+
void deref_twice_unsafe(int** pp) {
2+
**pp = 42; // Error
43
}
54

6-
void safe_deref(int** pp) {
5+
void deref_twice_safe(int** pp) {
76
if (pp && *pp) {
87
**pp = 42; // OK
98
}
9+
}
10+
11+
void deref_twice_partial(int** pp) {
12+
if (pp) {
13+
**pp = 42; // Error
14+
}
15+
}
16+
17+
void example_nonnull_inner(int * _Nonnull * _Nonnull pp) {
18+
**pp = 42; // OK
1019
}
Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,24 @@
1-
// _Nonnull annotations guarantee non-null
2-
void process(_Nonnull int* data) {
3-
*data = 42; // OK - data is guaranteed non-null
1+
void process(int* _Nonnull data) {
2+
*data = 42;
43
}
54

6-
void example(int* _Nullable x, int* _Nonnull y) {
7-
process(x); // warning - passing nullable to nonnull
8-
process(y); // OK
5+
void example(int* _Nullable x, int* _Nonnull y) {
6+
process(x);
7+
process(y);
8+
}
9+
10+
void deref_twice_unsafe(int* _Nullable * _Nullable pp) {
11+
**pp = 42;
12+
}
13+
14+
void deref_twice_safe(int* _Nullable * _Nullable pp) {
15+
if (pp && *pp) {
16+
**pp = 42;
17+
}
18+
}
19+
20+
void deref_twice_partial(int* _Nullable * _Nullable pp) {
21+
if (pp) {
22+
**pp = 42;
23+
}
924
}

0 commit comments

Comments
 (0)