2525#include " clang/Basic/SourceLocation.h"
2626#include " clang/Lex/Lexer.h"
2727#include " clang/Lex/Preprocessor.h"
28+ #include " llvm/ADT/APInt.h"
2829#include " llvm/ADT/APSInt.h"
2930#include " llvm/ADT/STLFunctionalExtras.h"
3031#include " llvm/ADT/SmallSet.h"
@@ -809,28 +810,86 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg,
809810 const CallExpr *Call;
810811 unsigned FmtArgIdx;
811812 const Expr *&UnsafeArg;
813+ ASTContext &Ctx;
814+
815+ // Returns an `Expr` representing the precision if specified, null
816+ // otherwise.
817+ // The parameter `Call` is a printf call and the parameter `Precision` is
818+ // the precision of a format specifier of the `Call`.
819+ //
820+ // For example, for the `printf("%d, %.10s", 10, p)` call
821+ // `Precision` can be the precision of either "%d" or "%.10s". The former
822+ // one will have `NotSpecified` kind.
823+ const Expr *
824+ getPrecisionAsExpr (const analyze_printf::OptionalAmount &Precision,
825+ const CallExpr *Call) {
826+ unsigned PArgIdx = -1 ;
827+
828+ if (Precision.hasDataArgument ())
829+ PArgIdx = Precision.getPositionalArgIndex () + FmtArgIdx;
830+ if (0 < PArgIdx && PArgIdx < Call->getNumArgs ()) {
831+ const Expr *PArg = Call->getArg (PArgIdx);
832+
833+ // Strip the cast if `PArg` is a cast-to-int expression:
834+ if (auto *CE = dyn_cast<CastExpr>(PArg);
835+ CE && CE->getType ()->isSignedIntegerType ())
836+ PArg = CE->getSubExpr ();
837+ return PArg;
838+ }
839+ if (Precision.getHowSpecified () ==
840+ analyze_printf::OptionalAmount::HowSpecified::Constant) {
841+ auto SizeTy = Ctx.getSizeType ();
842+ llvm::APSInt PArgVal = llvm::APSInt (
843+ llvm::APInt (Ctx.getTypeSize (SizeTy), Precision.getConstantAmount ()),
844+ true );
845+
846+ return IntegerLiteral::Create (Ctx, PArgVal, Ctx.getSizeType (), {});
847+ }
848+ return nullptr ;
849+ }
812850
813851 public:
814852 StringFormatStringHandler (const CallExpr *Call, unsigned FmtArgIdx,
815- const Expr *&UnsafeArg)
816- : Call(Call), FmtArgIdx(FmtArgIdx), UnsafeArg(UnsafeArg) {}
853+ const Expr *&UnsafeArg, ASTContext &Ctx )
854+ : Call(Call), FmtArgIdx(FmtArgIdx), UnsafeArg(UnsafeArg), Ctx(Ctx) {}
817855
818856 bool HandlePrintfSpecifier (const analyze_printf::PrintfSpecifier &FS,
819857 const char *startSpecifier,
820858 unsigned specifierLen,
821859 const TargetInfo &Target) override {
822- if (FS.getConversionSpecifier ().getKind () ==
823- analyze_printf::PrintfConversionSpecifier::sArg ) {
824- unsigned ArgIdx = FS.getPositionalArgIndex () + FmtArgIdx;
825-
826- if (0 < ArgIdx && ArgIdx < Call->getNumArgs ())
827- if (!isNullTermPointer (Call->getArg (ArgIdx))) {
828- UnsafeArg = Call->getArg (ArgIdx); // output
829- // returning false stops parsing immediately
830- return false ;
831- }
832- }
833- return true ; // continue parsing
860+ if (FS.getConversionSpecifier ().getKind () !=
861+ analyze_printf::PrintfConversionSpecifier::sArg )
862+ return true ; // continue parsing
863+
864+ unsigned ArgIdx = FS.getPositionalArgIndex () + FmtArgIdx;
865+
866+ if (!(0 < ArgIdx && ArgIdx < Call->getNumArgs ()))
867+ // If the `ArgIdx` is invalid, give up.
868+ return true ; // continue parsing
869+
870+ const Expr *Arg = Call->getArg (ArgIdx);
871+
872+ if (isNullTermPointer (Arg))
873+ // If Arg is a null-terminated pointer, it is safe anyway.
874+ return true ; // continue parsing
875+
876+ // Otherwise, check if the specifier has a precision and if the character
877+ // pointer is safely bound by the precision:
878+ auto LengthModifier = FS.getLengthModifier ();
879+ QualType ArgType = Arg->getType ();
880+ bool IsArgTypeValid = // Is ArgType a character pointer type?
881+ ArgType->isPointerType () &&
882+ (LengthModifier.getKind () == LengthModifier.AsWideChar
883+ ? ArgType->getPointeeType ()->isWideCharType ()
884+ : ArgType->getPointeeType ()->isCharType ());
885+
886+ if (auto *Precision = getPrecisionAsExpr (FS.getPrecision (), Call);
887+ Precision && IsArgTypeValid)
888+ if (isPtrBufferSafe (Arg, Precision, Ctx))
889+ return true ;
890+ // Handle unsafe case:
891+ UnsafeArg = Call->getArg (ArgIdx); // output
892+ return false ; // returning false stops parsing immediately
834893 }
835894 };
836895
@@ -846,7 +905,7 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg,
846905 else
847906 goto CHECK_UNSAFE_PTR;
848907
849- StringFormatStringHandler Handler (Call, FmtArgIdx, UnsafeArg);
908+ StringFormatStringHandler Handler (Call, FmtArgIdx, UnsafeArg, Ctx );
850909
851910 return analyze_format_string::ParsePrintfString (
852911 Handler, FmtStr.begin (), FmtStr.end (), Ctx.getLangOpts (),
0 commit comments