@@ -1989,7 +1989,7 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) {
19891989}
19901990
19911991// Helper function for CheckHLSLBuiltinFunctionCall
1992- static bool CheckVectorElementCallArgs (Sema *S, CallExpr *TheCall) {
1992+ static bool CheckVectorElementCallArgs (Sema *S, CallExpr *TheCall, unsigned NumArgs ) {
19931993 assert (TheCall->getNumArgs () > 1 );
19941994 ExprResult A = TheCall->getArg (0 );
19951995
@@ -1999,7 +1999,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
19991999 SourceLocation BuiltinLoc = TheCall->getBeginLoc ();
20002000
20012001 bool AllBArgAreVectors = true ;
2002- for (unsigned i = 1 ; i < TheCall-> getNumArgs () ; ++i) {
2002+ for (unsigned i = 1 ; i < NumArgs ; ++i) {
20032003 ExprResult B = TheCall->getArg (i);
20042004 QualType ArgTyB = B.get ()->getType ();
20052005 auto *VecTyB = ArgTyB->getAs <VectorType>();
@@ -2050,6 +2050,10 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) {
20502050 return false ;
20512051}
20522052
2053+ static bool CheckVectorElementCallArgs (Sema *S, CallExpr *TheCall) {
2054+ return CheckVectorElementCallArgs (S, TheCall, TheCall->getNumArgs ());
2055+ }
2056+
20532057static bool CheckAllArgsHaveSameType (Sema *S, CallExpr *TheCall) {
20542058 assert (TheCall->getNumArgs () > 1 );
20552059 QualType ArgTy0 = TheCall->getArg (0 )->getType ();
@@ -2092,10 +2096,10 @@ static bool CheckArgTypeIsCorrect(
20922096 return false ;
20932097}
20942098
2095- static bool CheckAllArgTypesAreCorrect (
2096- Sema *S, CallExpr *TheCall, QualType ExpectedType,
2099+ static bool CheckArgTypesAreCorrect (
2100+ Sema *S, CallExpr *TheCall, unsigned NumArgs, QualType ExpectedType,
20972101 llvm::function_ref<bool (clang::QualType PassedType)> Check) {
2098- for (unsigned i = 0 ; i < TheCall-> getNumArgs () ; ++i) {
2102+ for (unsigned i = 0 ; i < NumArgs ; ++i) {
20992103 Expr *Arg = TheCall->getArg (i);
21002104 if (CheckArgTypeIsCorrect (S, Arg, ExpectedType, Check)) {
21012105 return true ;
@@ -2104,6 +2108,13 @@ static bool CheckAllArgTypesAreCorrect(
21042108 return false ;
21052109}
21062110
2111+ static bool CheckAllArgTypesAreCorrect (
2112+ Sema *S, CallExpr *TheCall, QualType ExpectedType,
2113+ llvm::function_ref<bool (clang::QualType PassedType)> Check) {
2114+ return CheckArgTypesAreCorrect (S, TheCall, TheCall->getNumArgs (),
2115+ ExpectedType, Check);
2116+ }
2117+
21072118static bool CheckAllArgsHaveFloatRepresentation (Sema *S, CallExpr *TheCall) {
21082119 auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
21092120 return !PassedType->hasFloatingRepresentation ();
@@ -2147,15 +2158,17 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
21472158 return true ;
21482159}
21492160
2150- static bool CheckNoDoubleVectors (Sema *S, CallExpr *TheCall) {
2161+ static bool CheckNoDoubleVectors (Sema *S, CallExpr *TheCall,
2162+ unsigned NumArgs, QualType ExpectedType) {
21512163 auto checkDoubleVector = [](clang::QualType PassedType) -> bool {
21522164 if (const auto *VecTy = PassedType->getAs <VectorType>())
21532165 return VecTy->getElementType ()->isDoubleType ();
21542166 return false ;
21552167 };
2156- return CheckAllArgTypesAreCorrect (S, TheCall, S-> Context . FloatTy ,
2157- checkDoubleVector);
2168+ return CheckArgTypesAreCorrect (S, TheCall, NumArgs,
2169+ ExpectedType, checkDoubleVector);
21582170}
2171+
21592172static bool CheckFloatingOrIntRepresentation (Sema *S, CallExpr *TheCall) {
21602173 auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
21612174 return !PassedType->hasIntegerRepresentation () &&
@@ -2471,7 +2484,21 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
24712484 return true ;
24722485 if (SemaRef.BuiltinVectorToScalarMath (TheCall))
24732486 return true ;
2474- if (CheckNoDoubleVectors (&SemaRef, TheCall))
2487+ if (CheckNoDoubleVectors (&SemaRef, TheCall,
2488+ TheCall->getNumArgs (), SemaRef.Context .FloatTy ))
2489+ return true ;
2490+ break ;
2491+ }
2492+ case Builtin::BI__builtin_hlsl_dot2add: {
2493+ if (SemaRef.checkArgCount (TheCall, 3 ))
2494+ return true ;
2495+ if (CheckVectorElementCallArgs (&SemaRef, TheCall, TheCall->getNumArgs () - 1 ))
2496+ return true ;
2497+ if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2498+ return true ;
2499+ if (CheckNoDoubleVectors (&SemaRef, TheCall,
2500+ TheCall->getNumArgs () - 1 ,
2501+ SemaRef.Context .HalfTy ))
24752502 return true ;
24762503 break ;
24772504 }
0 commit comments