@@ -729,26 +729,71 @@ void Sema::NarrowVariableToNonNull(const VarDecl *VD) {
729729 return ;
730730
731731 QualType OrigType = VD->getType ();
732- std::optional<NullabilityKind> Nullability = OrigType->getNullability ();
733732
734- // Only narrow if the type is nullable
735- if (!Nullability || *Nullability != NullabilityKind::Nullable)
733+ QualType CurrentType = GetNarrowedType (VD);
734+ if (CurrentType.isNull ())
735+ CurrentType = OrigType;
736+
737+ std::optional<NullabilityKind> Nullability = CurrentType->getNullability ();
738+
739+ if (Nullability && *Nullability == NullabilityKind::NonNull)
736740 return ;
737741
738- // Strip the existing nullability attribute first
739- QualType BaseType = OrigType;
742+ QualType BaseType = CurrentType;
740743 (void )AttributedType::stripOuterNullability (BaseType);
741744
742- // Create a non-null version of the type
743745 QualType NarrowedType = Context.getAttributedType (
744746 NullabilityKind::NonNull,
745747 BaseType,
746748 BaseType);
747749
748- // Store it in the current scope
749750 NullabilityNarrowingScopes.back ()[VD] = NarrowedType;
750751}
751752
753+ void Sema::NarrowVariablePointeeToNonNull (const VarDecl *VD) {
754+ if (!getLangOpts ().StrictNullability )
755+ return ;
756+ if (!VD || NullabilityNarrowingScopes.empty ())
757+ return ;
758+
759+ QualType CurrentType = GetNarrowedType (VD);
760+ if (CurrentType.isNull ())
761+ CurrentType = VD->getType ();
762+
763+ const PointerType *PtrType = CurrentType->getAs <PointerType>();
764+ if (!PtrType)
765+ return ;
766+
767+ QualType PointeeType = PtrType->getPointeeType ();
768+ std::optional<NullabilityKind> PointeeNullability = PointeeType->getNullability ();
769+
770+ if (PointeeNullability && *PointeeNullability == NullabilityKind::NonNull)
771+ return ;
772+
773+ QualType PointeeBase = PointeeType;
774+ (void )AttributedType::stripOuterNullability (PointeeBase);
775+
776+ QualType NarrowedPointee = Context.getAttributedType (
777+ NullabilityKind::NonNull,
778+ PointeeBase,
779+ PointeeBase);
780+
781+ std::optional<NullabilityKind> OuterNullability = CurrentType->getNullability ();
782+ QualType OuterBase = CurrentType;
783+ (void )AttributedType::stripOuterNullability (OuterBase);
784+
785+ QualType NewPointerType = Context.getPointerType (NarrowedPointee);
786+
787+ if (OuterNullability) {
788+ NewPointerType = Context.getAttributedType (
789+ *OuterNullability,
790+ NewPointerType,
791+ NewPointerType);
792+ }
793+
794+ NullabilityNarrowingScopes.back ()[VD] = NewPointerType;
795+ }
796+
752797QualType Sema::GetNarrowedType (const VarDecl *VD) const {
753798 if (!getLangOpts ().StrictNullability )
754799 return QualType ();
@@ -826,10 +871,23 @@ const VarDecl* Sema::AnalyzeConditionForNullCheck(Expr *Cond, bool &IsNegated) {
826871 }
827872 }
828873
874+ // Handle: *ptr (dereferenced pointer)
875+ if (auto *UO = dyn_cast<UnaryOperator>(E)) {
876+ if (UO->getOpcode () == UO_Deref) {
877+ Expr *SubExpr = UO->getSubExpr ()->IgnoreParenImpCasts ();
878+ if (auto *DRE = dyn_cast<DeclRefExpr>(SubExpr)) {
879+ if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl ())) {
880+ if (VD->getType ()->isPointerType ()) {
881+ return VD;
882+ }
883+ }
884+ }
885+ }
886+ }
887+
829888 // Handle: ptr (implicit boolean conversion)
830889 if (auto *DRE = dyn_cast<DeclRefExpr>(E)) {
831890 if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl ())) {
832- // Check if this is a pointer type
833891 if (VD->getType ()->isPointerType ()) {
834892 return VD;
835893 }
@@ -882,23 +940,55 @@ void Sema::CollectAndCheckedVariables(Expr *Cond,
882940 // Check if this is an AND expression
883941 if (auto *BO = dyn_cast<BinaryOperator>(E)) {
884942 if (BO->getOpcode () == BO_LAnd) {
885- // Recursively collect from both sides of the AND
886943 CollectAndCheckedVariables (BO->getLHS (), Vars);
887944 CollectAndCheckedVariables (BO->getRHS (), Vars);
888945 return ;
889946 }
890947 }
891948
892- // Otherwise, check if this is a simple non-negated null check
893949 bool IsNegated = false ;
894950 const VarDecl *VD = AnalyzeConditionForNullCheck (Cond, IsNegated);
895951
896- // For AND patterns like "p && q", we want non-negated checks
897952 if (VD && !IsNegated) {
898953 Vars.push_back (VD);
899954 }
900955}
901956
957+ void Sema::CollectAndCheckedDereferences (Expr *Cond,
958+ SmallVectorImpl<const VarDecl*> &Vars) {
959+ if (!Cond)
960+ return ;
961+
962+ Expr *E = Cond->IgnoreParenImpCasts ();
963+
964+ if (auto *BO = dyn_cast<BinaryOperator>(E)) {
965+ if (BO->getOpcode () == BO_LAnd) {
966+ CollectAndCheckedDereferences (BO->getLHS (), Vars);
967+ CollectAndCheckedDereferences (BO->getRHS (), Vars);
968+ return ;
969+ }
970+ }
971+
972+ if (auto *UO = dyn_cast<UnaryOperator>(E)) {
973+ if (UO->getOpcode () == UO_LNot) {
974+ E = UO->getSubExpr ()->IgnoreParenImpCasts ();
975+ }
976+ }
977+
978+ if (auto *UO = dyn_cast<UnaryOperator>(E)) {
979+ if (UO->getOpcode () == UO_Deref) {
980+ Expr *SubExpr = UO->getSubExpr ()->IgnoreParenImpCasts ();
981+ if (auto *DRE = dyn_cast<DeclRefExpr>(SubExpr)) {
982+ if (auto *VD = dyn_cast<VarDecl>(DRE->getDecl ())) {
983+ if (VD->getType ()->isPointerType ()) {
984+ Vars.push_back (VD);
985+ }
986+ }
987+ }
988+ }
989+ }
990+ }
991+
902992// strict-nullability: Check if an expression contains any function calls that could have side effects
903993// This is used to determine if it's safe to apply narrowing from a condition
904994static bool ContainsSideEffectingCall (Expr *E) {
0 commit comments