@@ -1970,26 +1970,40 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
1970
1970
// Check if \p Ty is a valid type for the elementwise math builtins. If it is
1971
1971
// not a valid type, emit an error message and return true. Otherwise return
1972
1972
// false.
1973
- static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
1974
- QualType ArgTy, int ArgIndex) {
1975
- if (!ArgTy->getAs<VectorType>() &&
1976
- !ConstantMatrixType::isValidElementType(ArgTy)) {
1977
- return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1978
- << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
1979
- }
1980
-
1981
- return false;
1982
- }
1983
-
1984
- static bool checkFPMathBuiltinElementType(Sema &S, SourceLocation Loc,
1985
- QualType ArgTy, int ArgIndex) {
1973
+ static bool
1974
+ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
1975
+ Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
1976
+ int ArgOrdinal) {
1986
1977
QualType EltTy = ArgTy;
1987
1978
if (auto *VecTy = EltTy->getAs<VectorType>())
1988
1979
EltTy = VecTy->getElementType();
1989
1980
1990
- if (!EltTy->isRealFloatingType()) {
1991
- return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1992
- << ArgIndex << /* vector or float ty*/ 5 << ArgTy;
1981
+ switch (ArgTyRestr) {
1982
+ case Sema::EltwiseBuiltinArgTyRestriction::None:
1983
+ if (!ArgTy->getAs<VectorType>() &&
1984
+ !ConstantMatrixType::isValidElementType(ArgTy)) {
1985
+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1986
+ << ArgOrdinal << /* vector, integer or float ty*/ 0 << ArgTy;
1987
+ }
1988
+ break;
1989
+ case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
1990
+ if (!EltTy->isRealFloatingType()) {
1991
+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1992
+ << ArgOrdinal << /* vector or float ty*/ 5 << ArgTy;
1993
+ }
1994
+ break;
1995
+ case Sema::EltwiseBuiltinArgTyRestriction::IntegerTy:
1996
+ if (!EltTy->isIntegerType()) {
1997
+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
1998
+ << ArgOrdinal << /* vector or int ty*/ 10 << ArgTy;
1999
+ }
2000
+ break;
2001
+ case Sema::EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy:
2002
+ if (EltTy->isUnsignedIntegerType()) {
2003
+ return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
2004
+ << 1 << /* signed integer or float ty*/ 3 << ArgTy;
2005
+ }
2006
+ break;
1993
2007
}
1994
2008
1995
2009
return false;
@@ -2695,23 +2709,11 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
2695
2709
2696
2710
// __builtin_elementwise_abs restricts the element type to signed integers or
2697
2711
// floating point types only.
2698
- case Builtin::BI__builtin_elementwise_abs: {
2699
- if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2712
+ case Builtin::BI__builtin_elementwise_abs:
2713
+ if (PrepareBuiltinElementwiseMathOneArgCall(
2714
+ TheCall, EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy))
2700
2715
return ExprError();
2701
-
2702
- QualType ArgTy = TheCall->getArg(0)->getType();
2703
- QualType EltTy = ArgTy;
2704
-
2705
- if (auto *VecTy = EltTy->getAs<VectorType>())
2706
- EltTy = VecTy->getElementType();
2707
- if (EltTy->isUnsignedIntegerType()) {
2708
- Diag(TheCall->getArg(0)->getBeginLoc(),
2709
- diag::err_builtin_invalid_arg_type)
2710
- << 1 << /* signed integer or float ty*/ 3 << ArgTy;
2711
- return ExprError();
2712
- }
2713
2716
break;
2714
- }
2715
2717
2716
2718
// These builtins restrict the element type to floating point
2717
2719
// types only.
@@ -2737,81 +2739,46 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
2737
2739
case Builtin::BI__builtin_elementwise_tan:
2738
2740
case Builtin::BI__builtin_elementwise_tanh:
2739
2741
case Builtin::BI__builtin_elementwise_trunc:
2740
- case Builtin::BI__builtin_elementwise_canonicalize: {
2741
- if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2742
- return ExprError();
2743
-
2744
- QualType ArgTy = TheCall->getArg(0)->getType();
2745
- if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
2746
- ArgTy, 1))
2742
+ case Builtin::BI__builtin_elementwise_canonicalize:
2743
+ if (PrepareBuiltinElementwiseMathOneArgCall(
2744
+ TheCall, EltwiseBuiltinArgTyRestriction::FloatTy))
2747
2745
return ExprError();
2748
2746
break;
2749
- }
2750
- case Builtin::BI__builtin_elementwise_fma: {
2747
+ case Builtin::BI__builtin_elementwise_fma:
2751
2748
if (BuiltinElementwiseTernaryMath(TheCall))
2752
2749
return ExprError();
2753
2750
break;
2754
- }
2755
2751
2756
2752
// These builtins restrict the element type to floating point
2757
2753
// types only, and take in two arguments.
2758
2754
case Builtin::BI__builtin_elementwise_minimum:
2759
2755
case Builtin::BI__builtin_elementwise_maximum:
2760
2756
case Builtin::BI__builtin_elementwise_atan2:
2761
2757
case Builtin::BI__builtin_elementwise_fmod:
2762
- case Builtin::BI__builtin_elementwise_pow: {
2763
- if (BuiltinElementwiseMath(TheCall, /*FPOnly=*/true))
2758
+ case Builtin::BI__builtin_elementwise_pow:
2759
+ if (BuiltinElementwiseMath(TheCall,
2760
+ EltwiseBuiltinArgTyRestriction::FloatTy))
2764
2761
return ExprError();
2765
2762
break;
2766
- }
2767
-
2768
2763
// These builtins restrict the element type to integer
2769
2764
// types only.
2770
2765
case Builtin::BI__builtin_elementwise_add_sat:
2771
- case Builtin::BI__builtin_elementwise_sub_sat: {
2772
- if (BuiltinElementwiseMath(TheCall))
2773
- return ExprError();
2774
-
2775
- const Expr *Arg = TheCall->getArg(0);
2776
- QualType ArgTy = Arg->getType();
2777
- QualType EltTy = ArgTy;
2778
-
2779
- if (auto *VecTy = EltTy->getAs<VectorType>())
2780
- EltTy = VecTy->getElementType();
2781
-
2782
- if (!EltTy->isIntegerType()) {
2783
- Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2784
- << 1 << /* integer ty */ 6 << ArgTy;
2766
+ case Builtin::BI__builtin_elementwise_sub_sat:
2767
+ if (BuiltinElementwiseMath(TheCall,
2768
+ EltwiseBuiltinArgTyRestriction::IntegerTy))
2785
2769
return ExprError();
2786
- }
2787
2770
break;
2788
- }
2789
-
2790
2771
case Builtin::BI__builtin_elementwise_min:
2791
2772
case Builtin::BI__builtin_elementwise_max:
2792
2773
if (BuiltinElementwiseMath(TheCall))
2793
2774
return ExprError();
2794
2775
break;
2795
2776
case Builtin::BI__builtin_elementwise_popcount:
2796
- case Builtin::BI__builtin_elementwise_bitreverse: {
2797
- if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
2798
- return ExprError();
2799
-
2800
- const Expr *Arg = TheCall->getArg(0);
2801
- QualType ArgTy = Arg->getType();
2802
- QualType EltTy = ArgTy;
2803
-
2804
- if (auto *VecTy = EltTy->getAs<VectorType>())
2805
- EltTy = VecTy->getElementType();
2806
-
2807
- if (!EltTy->isIntegerType()) {
2808
- Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2809
- << 1 << /* integer ty */ 6 << ArgTy;
2777
+ case Builtin::BI__builtin_elementwise_bitreverse:
2778
+ if (PrepareBuiltinElementwiseMathOneArgCall(
2779
+ TheCall, EltwiseBuiltinArgTyRestriction::IntegerTy))
2810
2780
return ExprError();
2811
- }
2812
2781
break;
2813
- }
2814
-
2815
2782
case Builtin::BI__builtin_elementwise_copysign: {
2816
2783
if (checkArgCount(TheCall, 2))
2817
2784
return ExprError();
@@ -2823,10 +2790,12 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
2823
2790
2824
2791
QualType MagnitudeTy = Magnitude.get()->getType();
2825
2792
QualType SignTy = Sign.get()->getType();
2826
- if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
2827
- MagnitudeTy, 1) ||
2828
- checkFPMathBuiltinElementType(*this, TheCall->getArg(1)->getBeginLoc(),
2829
- SignTy, 2)) {
2793
+ if (checkMathBuiltinElementType(
2794
+ *this, TheCall->getArg(0)->getBeginLoc(), MagnitudeTy,
2795
+ EltwiseBuiltinArgTyRestriction::FloatTy, 1) ||
2796
+ checkMathBuiltinElementType(
2797
+ *this, TheCall->getArg(1)->getBeginLoc(), SignTy,
2798
+ EltwiseBuiltinArgTyRestriction::FloatTy, 2)) {
2830
2799
return ExprError();
2831
2800
}
2832
2801
@@ -14666,7 +14635,8 @@ static ExprResult BuiltinVectorMathConversions(Sema &S, Expr *E) {
14666
14635
return S.UsualUnaryFPConversions(Res.get());
14667
14636
}
14668
14637
14669
- bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
14638
+ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(
14639
+ CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr) {
14670
14640
if (checkArgCount(TheCall, 1))
14671
14641
return true;
14672
14642
@@ -14677,15 +14647,17 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
14677
14647
TheCall->setArg(0, A.get());
14678
14648
QualType TyA = A.get()->getType();
14679
14649
14680
- if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
14650
+ if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA,
14651
+ ArgTyRestr, 1))
14681
14652
return true;
14682
14653
14683
14654
TheCall->setType(TyA);
14684
14655
return false;
14685
14656
}
14686
14657
14687
- bool Sema::BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly) {
14688
- if (auto Res = BuiltinVectorMath(TheCall, FPOnly); Res.has_value()) {
14658
+ bool Sema::BuiltinElementwiseMath(CallExpr *TheCall,
14659
+ EltwiseBuiltinArgTyRestriction ArgTyRestr) {
14660
+ if (auto Res = BuiltinVectorMath(TheCall, ArgTyRestr); Res.has_value()) {
14689
14661
TheCall->setType(*Res);
14690
14662
return false;
14691
14663
}
@@ -14718,8 +14690,9 @@ static bool checkBuiltinVectorMathMixedEnums(Sema &S, Expr *LHS, Expr *RHS,
14718
14690
return false;
14719
14691
}
14720
14692
14721
- std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
14722
- bool FPOnly) {
14693
+ std::optional<QualType>
14694
+ Sema::BuiltinVectorMath(CallExpr *TheCall,
14695
+ EltwiseBuiltinArgTyRestriction ArgTyRestr) {
14723
14696
if (checkArgCount(TheCall, 2))
14724
14697
return std::nullopt;
14725
14698
@@ -14740,26 +14713,21 @@ std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
14740
14713
QualType TyA = Args[0]->getType();
14741
14714
QualType TyB = Args[1]->getType();
14742
14715
14716
+ if (checkMathBuiltinElementType(*this, LocA, TyA, ArgTyRestr, 1))
14717
+ return std::nullopt;
14718
+
14743
14719
if (TyA.getCanonicalType() != TyB.getCanonicalType()) {
14744
14720
Diag(LocA, diag::err_typecheck_call_different_arg_types) << TyA << TyB;
14745
14721
return std::nullopt;
14746
14722
}
14747
14723
14748
- if (FPOnly) {
14749
- if (checkFPMathBuiltinElementType(*this, LocA, TyA, 1))
14750
- return std::nullopt;
14751
- } else {
14752
- if (checkMathBuiltinElementType(*this, LocA, TyA, 1))
14753
- return std::nullopt;
14754
- }
14755
-
14756
14724
TheCall->setArg(0, Args[0]);
14757
14725
TheCall->setArg(1, Args[1]);
14758
14726
return TyA;
14759
14727
}
14760
14728
14761
- bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
14762
- bool CheckForFloatArgs ) {
14729
+ bool Sema::BuiltinElementwiseTernaryMath(
14730
+ CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr ) {
14763
14731
if (checkArgCount(TheCall, 3))
14764
14732
return true;
14765
14733
@@ -14779,20 +14747,11 @@ bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
14779
14747
Args[I] = Converted.get();
14780
14748
}
14781
14749
14782
- if (CheckForFloatArgs) {
14783
- int ArgOrdinal = 1;
14784
- for (Expr *Arg : Args) {
14785
- if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
14786
- Arg->getType(), ArgOrdinal++))
14787
- return true;
14788
- }
14789
- } else {
14790
- int ArgOrdinal = 1;
14791
- for (Expr *Arg : Args) {
14792
- if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
14793
- ArgOrdinal++))
14794
- return true;
14795
- }
14750
+ int ArgOrdinal = 1;
14751
+ for (Expr *Arg : Args) {
14752
+ if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
14753
+ ArgTyRestr, ArgOrdinal++))
14754
+ return true;
14796
14755
}
14797
14756
14798
14757
for (int I = 1; I < 3; ++I) {
0 commit comments