@@ -2493,15 +2493,27 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
24932493 // Check number of arguments should be 3
24942494 if (SemaRef.checkArgCount (TheCall, 3 ))
24952495 return true ;
2496- // Check first two arguments should be vectors of same length
2497- if (CheckVectorElementCallArgs (&SemaRef, TheCall, TheCall->getNumArgs () - 1 ))
2496+
2497+ // Check first two arguments are vector of length 2 with half data type
2498+ auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool {
2499+ if (const auto *VecTy = PassedType->getAs <VectorType>())
2500+ return !(VecTy->getNumElements () == 2 &&
2501+ VecTy->getElementType ()->isHalfType ());
24982502 return true ;
2499- if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2503+ };
2504+ if (CheckArgTypeIsCorrect (&SemaRef, TheCall->getArg (0 ),
2505+ SemaRef.getASTContext ().HalfTy ,
2506+ checkHalfVectorOfSize2))
25002507 return true ;
2501- if (CheckNoDoubleVectors (&SemaRef, TheCall,
2502- TheCall->getNumArgs () - 1 ,
2503- SemaRef.Context .HalfTy ))
2508+ if (CheckArgTypeIsCorrect (&SemaRef, TheCall->getArg (1 ),
2509+ SemaRef.getASTContext ().HalfTy ,
2510+ checkHalfVectorOfSize2))
2511+ return true ;
2512+
2513+ // Check third argument is a float
2514+ if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
25042515 return true ;
2516+ TheCall->setType (TheCall->getArg (2 )->getType ());
25052517 break ;
25062518 }
25072519 case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
0 commit comments