@@ -2490,15 +2490,27 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
2490
2490
// Check number of arguments should be 3
2491
2491
if (SemaRef.checkArgCount (TheCall, 3 ))
2492
2492
return true ;
2493
- // Check first two arguments should be vectors of same length
2494
- if (CheckVectorElementCallArgs (&SemaRef, TheCall, TheCall->getNumArgs () - 1 ))
2493
+
2494
+ // Check first two arguments are vector of length 2 with half data type
2495
+ auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool {
2496
+ if (const auto *VecTy = PassedType->getAs <VectorType>())
2497
+ return !(VecTy->getNumElements () == 2 &&
2498
+ VecTy->getElementType ()->isHalfType ());
2495
2499
return true ;
2496
- if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2500
+ };
2501
+ if (CheckArgTypeIsCorrect (&SemaRef, TheCall->getArg (0 ),
2502
+ SemaRef.getASTContext ().HalfTy ,
2503
+ checkHalfVectorOfSize2))
2497
2504
return true ;
2498
- if (CheckNoDoubleVectors (&SemaRef, TheCall,
2499
- TheCall->getNumArgs () - 1 ,
2500
- SemaRef.Context .HalfTy ))
2505
+ if (CheckArgTypeIsCorrect (&SemaRef, TheCall->getArg (1 ),
2506
+ SemaRef.getASTContext ().HalfTy ,
2507
+ checkHalfVectorOfSize2))
2508
+ return true ;
2509
+
2510
+ // Check third argument is a float
2511
+ if (CheckArgTypeMatches (&SemaRef, TheCall->getArg (2 ), SemaRef.getASTContext ().FloatTy ))
2501
2512
return true ;
2513
+ TheCall->setType (TheCall->getArg (2 )->getType ());
2502
2514
break ;
2503
2515
}
2504
2516
case Builtin::BI__builtin_hlsl_elementwise_firstbithigh:
0 commit comments