@@ -2493,15 +2493,27 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
2493
2493
// Check number of arguments should be 3
2494
2494
if (SemaRef.checkArgCount (TheCall, 3 ))
2495
2495
return true ;
2496
- // Check first two arguments should be vectors of same length
2497
- if (CheckVectorElementCallArgs (&SemaRef, TheCall, TheCall->getNumArgs () - 1 ))
2496
+
2497
+ // Check first two arguments are vector of length 2 with half data type
2498
+ auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool {
2499
+ if (const auto *VecTy = PassedType->getAs <VectorType>())
2500
+ return !(VecTy->getNumElements () == 2 &&
2501
+ VecTy->getElementType ()->isHalfType ());
2498
2502
return true ;
2499
- if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2503
+ };
2504
+ if (CheckArgTypeIsCorrect (&SemaRef, TheCall->getArg (0 ),
2505
+ SemaRef.getASTContext ().HalfTy ,
2506
+ checkHalfVectorOfSize2))
2500
2507
return true ;
2501
- if (CheckNoDoubleVectors (&SemaRef, TheCall,
2502
- TheCall->getNumArgs () - 1 ,
2503
- SemaRef.Context .HalfTy ))
2508
+ if (CheckArgTypeIsCorrect (&SemaRef, TheCall->getArg (1 ),
2509
+ SemaRef.getASTContext ().HalfTy ,
2510
+ checkHalfVectorOfSize2))
2511
+ return true ;
2512
+
2513
+ // Check third argument is a float
2514
+ if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2504
2515
return true ;
2516
+ TheCall->setType (TheCall->getArg (2 )->getType ());
2505
2517
break ;
2506
2518
}
2507
2519
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
0 commit comments