@@ -1970,26 +1970,40 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
19701970// Check if \p Ty is a valid type for the elementwise math builtins. If it is
19711971// not a valid type, emit an error message and return true. Otherwise return
19721972// false.
1973- static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
1974- QualType ArgTy, int ArgIndex) {
1975- if (!ArgTy->getAs<VectorType>() &&
1976- !ConstantMatrixType::isValidElementType(ArgTy)) {
1977- return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1978- << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
1979- }
1980-
1981- return false;
1982- }
1983-
1984- static bool checkFPMathBuiltinElementType(Sema &S, SourceLocation Loc,
1985- QualType ArgTy, int ArgIndex) {
1973+ static bool
1974+ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
1975+ Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
1976+ int ArgOrdinal) {
19861977 QualType EltTy = ArgTy;
19871978 if (auto *VecTy = EltTy->getAs<VectorType>())
19881979 EltTy = VecTy->getElementType();
19891980
1990- if (!EltTy->isRealFloatingType()) {
1991- return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1992- << ArgIndex << /* vector or float ty*/ 5 << ArgTy;
1981+ switch (ArgTyRestr) {
1982+ case Sema::EltwiseBuiltinArgTyRestriction::None:
1983+ if (!ArgTy->getAs<VectorType>() &&
1984+ !ConstantMatrixType::isValidElementType(ArgTy)) {
1985+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1986+ << ArgOrdinal << /* vector, integer or float ty*/ 0 << ArgTy;
1987+ }
1988+ break;
1989+ case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
1990+ if (!EltTy->isRealFloatingType()) {
1991+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1992+ << ArgOrdinal << /* vector or float ty*/ 5 << ArgTy;
1993+ }
1994+ break;
1995+ case Sema::EltwiseBuiltinArgTyRestriction::IntegerTy:
1996+ if (!EltTy->isIntegerType()) {
1997+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1998+ << ArgOrdinal << /* vector or int ty*/ 10 << ArgTy;
1999+ }
2000+ break;
2001+ case Sema::EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy:
2002+ if (EltTy->isUnsignedIntegerType()) {
2003+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2004+ << 1 << /* signed integer or float ty*/ 3 << ArgTy;
2005+ }
2006+ break;
19932007 }
19942008
19952009 return false;
@@ -2695,23 +2709,11 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
26952709
26962710 // __builtin_elementwise_abs restricts the element type to signed integers or
26972711 // floating point types only.
2698- case Builtin::BI__builtin_elementwise_abs: {
2699- if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2712+ case Builtin::BI__builtin_elementwise_abs:
2713+ if (PrepareBuiltinElementwiseMathOneArgCall(
2714+ TheCall, EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy))
27002715 return ExprError();
2701-
2702- QualType ArgTy = TheCall->getArg(0)->getType();
2703- QualType EltTy = ArgTy;
2704-
2705- if (auto *VecTy = EltTy->getAs<VectorType>())
2706- EltTy = VecTy->getElementType();
2707- if (EltTy->isUnsignedIntegerType()) {
2708- Diag(TheCall->getArg(0)->getBeginLoc(),
2709- diag::err_builtin_invalid_arg_type)
2710- << 1 << /* signed integer or float ty*/ 3 << ArgTy;
2711- return ExprError();
2712- }
27132716 break;
2714- }
27152717
27162718 // These builtins restrict the element type to floating point
27172719 // types only.
@@ -2737,81 +2739,46 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
27372739 case Builtin::BI__builtin_elementwise_tan:
27382740 case Builtin::BI__builtin_elementwise_tanh:
27392741 case Builtin::BI__builtin_elementwise_trunc:
2740- case Builtin::BI__builtin_elementwise_canonicalize: {
2741- if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2742- return ExprError();
2743-
2744- QualType ArgTy = TheCall->getArg(0)->getType();
2745- if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
2746- ArgTy, 1))
2742+ case Builtin::BI__builtin_elementwise_canonicalize:
2743+ if (PrepareBuiltinElementwiseMathOneArgCall(
2744+ TheCall, EltwiseBuiltinArgTyRestriction::FloatTy))
27472745 return ExprError();
27482746 break;
2749- }
2750- case Builtin::BI__builtin_elementwise_fma: {
2747+ case Builtin::BI__builtin_elementwise_fma:
27512748 if (BuiltinElementwiseTernaryMath(TheCall))
27522749 return ExprError();
27532750 break;
2754- }
27552751
27562752 // These builtins restrict the element type to floating point
27572753 // types only, and take in two arguments.
27582754 case Builtin::BI__builtin_elementwise_minimum:
27592755 case Builtin::BI__builtin_elementwise_maximum:
27602756 case Builtin::BI__builtin_elementwise_atan2:
27612757 case Builtin::BI__builtin_elementwise_fmod:
2762- case Builtin::BI__builtin_elementwise_pow: {
2763- if (BuiltinElementwiseMath(TheCall, /*FPOnly=*/true))
2758+ case Builtin::BI__builtin_elementwise_pow:
2759+ if (BuiltinElementwiseMath(TheCall,
2760+ EltwiseBuiltinArgTyRestriction::FloatTy))
27642761 return ExprError();
27652762 break;
2766- }
2767-
27682763 // These builtins restrict the element type to integer
27692764 // types only.
27702765 case Builtin::BI__builtin_elementwise_add_sat:
2771- case Builtin::BI__builtin_elementwise_sub_sat: {
2772- if (BuiltinElementwiseMath(TheCall))
2773- return ExprError();
2774-
2775- const Expr *Arg = TheCall->getArg(0);
2776- QualType ArgTy = Arg->getType();
2777- QualType EltTy = ArgTy;
2778-
2779- if (auto *VecTy = EltTy->getAs<VectorType>())
2780- EltTy = VecTy->getElementType();
2781-
2782- if (!EltTy->isIntegerType()) {
2783- Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2784- << 1 << /* integer ty */ 6 << ArgTy;
2766+ case Builtin::BI__builtin_elementwise_sub_sat:
2767+ if (BuiltinElementwiseMath(TheCall,
2768+ EltwiseBuiltinArgTyRestriction::IntegerTy))
27852769 return ExprError();
2786- }
27872770 break;
2788- }
2789-
27902771 case Builtin::BI__builtin_elementwise_min:
27912772 case Builtin::BI__builtin_elementwise_max:
27922773 if (BuiltinElementwiseMath(TheCall))
27932774 return ExprError();
27942775 break;
27952776 case Builtin::BI__builtin_elementwise_popcount:
2796- case Builtin::BI__builtin_elementwise_bitreverse: {
2797- if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2798- return ExprError();
2799-
2800- const Expr *Arg = TheCall->getArg(0);
2801- QualType ArgTy = Arg->getType();
2802- QualType EltTy = ArgTy;
2803-
2804- if (auto *VecTy = EltTy->getAs<VectorType>())
2805- EltTy = VecTy->getElementType();
2806-
2807- if (!EltTy->isIntegerType()) {
2808- Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2809- << 1 << /* integer ty */ 6 << ArgTy;
2777+ case Builtin::BI__builtin_elementwise_bitreverse:
2778+ if (PrepareBuiltinElementwiseMathOneArgCall(
2779+ TheCall, EltwiseBuiltinArgTyRestriction::IntegerTy))
28102780 return ExprError();
2811- }
28122781 break;
2813- }
2814-
28152782 case Builtin::BI__builtin_elementwise_copysign: {
28162783 if (checkArgCount(TheCall, 2))
28172784 return ExprError();
@@ -2823,10 +2790,12 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
28232790
28242791 QualType MagnitudeTy = Magnitude.get()->getType();
28252792 QualType SignTy = Sign.get()->getType();
2826- if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
2827- MagnitudeTy, 1) ||
2828- checkFPMathBuiltinElementType(*this, TheCall->getArg(1)->getBeginLoc(),
2829- SignTy, 2)) {
2793+ if (checkMathBuiltinElementType(
2794+ *this, TheCall->getArg(0)->getBeginLoc(), MagnitudeTy,
2795+ EltwiseBuiltinArgTyRestriction::FloatTy, 1) ||
2796+ checkMathBuiltinElementType(
2797+ *this, TheCall->getArg(1)->getBeginLoc(), SignTy,
2798+ EltwiseBuiltinArgTyRestriction::FloatTy, 2)) {
28302799 return ExprError();
28312800 }
28322801
@@ -14662,7 +14631,8 @@ static ExprResult BuiltinVectorMathConversions(Sema &S, Expr *E) {
1466214631 return S.UsualUnaryFPConversions(Res.get());
1466314632}
1466414633
14665- bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
14634+ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(
14635+ CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr) {
1466614636 if (checkArgCount(TheCall, 1))
1466714637 return true;
1466814638
@@ -14673,15 +14643,17 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
1467314643 TheCall->setArg(0, A.get());
1467414644 QualType TyA = A.get()->getType();
1467514645
14676- if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
14646+ if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA,
14647+ ArgTyRestr, 1))
1467714648 return true;
1467814649
1467914650 TheCall->setType(TyA);
1468014651 return false;
1468114652}
1468214653
14683- bool Sema::BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly) {
14684- if (auto Res = BuiltinVectorMath(TheCall, FPOnly); Res.has_value()) {
14654+ bool Sema::BuiltinElementwiseMath(CallExpr *TheCall,
14655+ EltwiseBuiltinArgTyRestriction ArgTyRestr) {
14656+ if (auto Res = BuiltinVectorMath(TheCall, ArgTyRestr); Res.has_value()) {
1468514657 TheCall->setType(*Res);
1468614658 return false;
1468714659 }
@@ -14714,8 +14686,9 @@ static bool checkBuiltinVectorMathMixedEnums(Sema &S, Expr *LHS, Expr *RHS,
1471414686 return false;
1471514687}
1471614688
14717- std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
14718- bool FPOnly) {
14689+ std::optional<QualType>
14690+ Sema::BuiltinVectorMath(CallExpr *TheCall,
14691+ EltwiseBuiltinArgTyRestriction ArgTyRestr) {
1471914692 if (checkArgCount(TheCall, 2))
1472014693 return std::nullopt;
1472114694
@@ -14736,26 +14709,21 @@ std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
1473614709 QualType TyA = Args[0]->getType();
1473714710 QualType TyB = Args[1]->getType();
1473814711
14712+ if (checkMathBuiltinElementType(*this, LocA, TyA, ArgTyRestr, 1))
14713+ return std::nullopt;
14714+
1473914715 if (TyA.getCanonicalType() != TyB.getCanonicalType()) {
1474014716 Diag(LocA, diag::err_typecheck_call_different_arg_types) << TyA << TyB;
1474114717 return std::nullopt;
1474214718 }
1474314719
14744- if (FPOnly) {
14745- if (checkFPMathBuiltinElementType(*this, LocA, TyA, 1))
14746- return std::nullopt;
14747- } else {
14748- if (checkMathBuiltinElementType(*this, LocA, TyA, 1))
14749- return std::nullopt;
14750- }
14751-
1475214720 TheCall->setArg(0, Args[0]);
1475314721 TheCall->setArg(1, Args[1]);
1475414722 return TyA;
1475514723}
1475614724
14757- bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
14758- bool CheckForFloatArgs ) {
14725+ bool Sema::BuiltinElementwiseTernaryMath(
14726+ CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr ) {
1475914727 if (checkArgCount(TheCall, 3))
1476014728 return true;
1476114729
@@ -14775,20 +14743,11 @@ bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
1477514743 Args[I] = Converted.get();
1477614744 }
1477714745
14778- if (CheckForFloatArgs) {
14779- int ArgOrdinal = 1;
14780- for (Expr *Arg : Args) {
14781- if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
14782- Arg->getType(), ArgOrdinal++))
14783- return true;
14784- }
14785- } else {
14786- int ArgOrdinal = 1;
14787- for (Expr *Arg : Args) {
14788- if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
14789- ArgOrdinal++))
14790- return true;
14791- }
14746+ int ArgOrdinal = 1;
14747+ for (Expr *Arg : Args) {
14748+ if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
14749+ ArgTyRestr, ArgOrdinal++))
14750+ return true;
1479214751 }
1479314752
1479414753 for (int I = 1; I < 3; ++I) {
0 commit comments