@@ -1989,7 +1989,7 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
19891989}
19901990
19911991// Helper function for CheckHLSLBuiltinFunctionCall
1992- static bool CheckVectorElementCallArgs (Sema *S, CallExpr *TheCall) {
1992+ static bool CheckVectorElementCallArgs (Sema *S, CallExpr *TheCall, unsigned NumArgs ) {
19931993 assert (TheCall->getNumArgs () > 1 );
19941994 ExprResult A = TheCall->getArg (0 );
19951995
@@ -1999,7 +1999,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
19991999 SourceLocation BuiltinLoc = TheCall->getBeginLoc ();
20002000
20012001 bool AllBArgAreVectors = true ;
2002- for (unsigned i = 1 ; i < TheCall-> getNumArgs () ; ++i) {
2002+ for (unsigned i = 1 ; i < NumArgs ; ++i) {
20032003 ExprResult B = TheCall->getArg (i);
20042004 QualType ArgTyB = B.get ()->getType ();
20052005 auto *VecTyB = ArgTyB->getAs <VectorType>();
@@ -2049,6 +2049,10 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
20492049 return false ;
20502050}
20512051
2052+ static bool CheckVectorElementCallArgs (Sema *S, CallExpr *TheCall) {
2053+ return CheckVectorElementCallArgs (S, TheCall, TheCall->getNumArgs ());
2054+ }
2055+
20522056static bool CheckAllArgsHaveSameType (Sema *S, CallExpr *TheCall) {
20532057 assert (TheCall->getNumArgs () > 1 );
20542058 QualType ArgTy0 = TheCall->getArg (0 )->getType ();
@@ -2091,10 +2095,10 @@ static bool CheckArgTypeIsCorrect(
20912095 return false ;
20922096}
20932097
2094- static bool CheckAllArgTypesAreCorrect (
2095- Sema *S, CallExpr *TheCall, QualType ExpectedType,
2098+ static bool CheckArgTypesAreCorrect (
2099+ Sema *S, CallExpr *TheCall, unsigned NumArgs, QualType ExpectedType,
20962100 llvm::function_ref<bool (clang::QualType PassedType)> Check) {
2097- for (unsigned i = 0 ; i < TheCall-> getNumArgs () ; ++i) {
2101+ for (unsigned i = 0 ; i < NumArgs ; ++i) {
20982102 Expr *Arg = TheCall->getArg (i);
20992103 if (CheckArgTypeIsCorrect (S, Arg, ExpectedType, Check)) {
21002104 return true ;
@@ -2103,6 +2107,13 @@ static bool CheckAllArgTypesAreCorrect(
21032107 return false ;
21042108}
21052109
2110+ static bool CheckAllArgTypesAreCorrect (
2111+ Sema *S, CallExpr *TheCall, QualType ExpectedType,
2112+ llvm::function_ref<bool (clang::QualType PassedType)> Check) {
2113+ return CheckArgTypesAreCorrect (S, TheCall, TheCall->getNumArgs (),
2114+ ExpectedType, Check);
2115+ }
2116+
21062117static bool CheckAllArgsHaveFloatRepresentation (Sema *S, CallExpr *TheCall) {
21072118 auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
21082119 return !PassedType->hasFloatingRepresentation ();
@@ -2146,15 +2157,17 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
21462157 return true ;
21472158}
21482159
2149- static bool CheckNoDoubleVectors (Sema *S, CallExpr *TheCall) {
2160+ static bool CheckNoDoubleVectors (Sema *S, CallExpr *TheCall,
2161+ unsigned NumArgs, QualType ExpectedType) {
21502162 auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
21512163 if (const auto *VecTy = PassedType->getAs <VectorType>())
21522164 return VecTy->getElementType ()->isDoubleType ();
21532165 return false ;
21542166 };
2155- return CheckAllArgTypesAreCorrect (S, TheCall, S-> Context . FloatTy ,
2156- checkDoubleVector);
2167+ return CheckArgTypesAreCorrect (S, TheCall, NumArgs,
2168+ ExpectedType, checkDoubleVector);
21572169}
2170+
21582171static bool CheckFloatingOrIntRepresentation (Sema *S, CallExpr *TheCall) {
21592172 auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
21602173 return !PassedType->hasIntegerRepresentation () &&
@@ -2468,7 +2481,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
24682481 return true ;
24692482 if (SemaRef.BuiltinVectorToScalarMath (TheCall))
24702483 return true ;
2471- if (CheckNoDoubleVectors (&SemaRef, TheCall))
2484+ if (CheckNoDoubleVectors (&SemaRef, TheCall,
2485+ TheCall->getNumArgs (), SemaRef.Context .FloatTy ))
2486+ return true ;
2487+ break ;
2488+ }
2489+ case Builtin::BI__builtin_hlsl_dot2add: {
2490+ if (SemaRef.checkArgCount (TheCall, 3 ))
2491+ return true ;
2492+ if (CheckVectorElementCallArgs (&SemaRef, TheCall, TheCall->getNumArgs () - 1 ))
2493+ return true ;
2494+ if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2495+ return true ;
2496+ if (CheckNoDoubleVectors (&SemaRef, TheCall,
2497+ TheCall->getNumArgs () - 1 ,
2498+ SemaRef.Context .HalfTy ))
24722499 return true ;
24732500 break ;
24742501 }
0 commit comments