diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index 472a0e25adc97..a797d8f7d37dc 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -15216,6 +15216,30 @@ class Sema final : public SemaBase { void performFunctionEffectAnalysis(TranslationUnitDecl *TU); ///@} + + // + // + // ------------------------------------------------------------------------- + // + // + + /// \name Common Helper Functions + /// Implementations are in Common.cpp + ///@{ +public: + static bool CheckArgTypeIsCorrect( + Sema *S, Expr *Arg, QualType ExpectedType, + llvm::function_ref Check); + + static bool CheckAllArgTypesAreCorrect( + Sema *S, CallExpr *TheCall, QualType ExpectedType, + llvm::function_ref Check); + + static bool CheckAllArgTypesAreCorrect(Sema *SemaPtr, CallExpr *TheCall, + unsigned int NumOfElts, + unsigned int expectedNumOfElts); + + ///@} }; DeductionFailureInfo diff --git a/clang/lib/Sema/CMakeLists.txt b/clang/lib/Sema/CMakeLists.txt index 19cf3a2db00fd..ddc340a51a3b2 100644 --- a/clang/lib/Sema/CMakeLists.txt +++ b/clang/lib/Sema/CMakeLists.txt @@ -17,6 +17,7 @@ add_clang_library(clangSema AnalysisBasedWarnings.cpp CheckExprLifetime.cpp CodeCompleteConsumer.cpp + Common.cpp DeclSpec.cpp DelayedDiagnostic.cpp HeuristicResolver.cpp diff --git a/clang/lib/Sema/Common.cpp b/clang/lib/Sema/Common.cpp new file mode 100644 index 0000000000000..2a2f4402e84aa --- /dev/null +++ b/clang/lib/Sema/Common.cpp @@ -0,0 +1,68 @@ +//===--- Common.cpp --- Semantic Analysis common implementation file ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This implements common functions used in SPIRV and HLSL semantic +// analysis constructs. +//===----------------------------------------------------------------------===// + +#include "clang/Sema/Sema.h" + +bool clang::Sema::CheckArgTypeIsCorrect( + Sema *S, Expr *Arg, QualType ExpectedType, + llvm::function_ref Check) { + QualType PassedType = Arg->getType(); + if (Check(PassedType)) { + if (auto *VecTyA = PassedType->getAs()) + ExpectedType = S->Context.getVectorType( + ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); + S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) + << PassedType << ExpectedType << 1 << 0 << 0; + return true; + } + return false; +} + +bool clang::Sema::CheckAllArgTypesAreCorrect( + Sema *S, CallExpr *TheCall, QualType ExpectedType, + llvm::function_ref Check) { + for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { + Expr *Arg = TheCall->getArg(i); + if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { + return true; + } + } + return false; +} + +bool clang::Sema::CheckAllArgTypesAreCorrect(Sema *SemaPtr, CallExpr *TheCall, + unsigned int NumOfElts, + unsigned int expectedNumOfElts) { + if (SemaPtr->checkArgCount(TheCall, NumOfElts)) { + return true; + } + + for (unsigned i = 0; i < NumOfElts; i++) { + Expr *localArg = TheCall->getArg(i); + QualType PassedType = localArg->getType(); + QualType ExpectedType = SemaPtr->Context.getVectorType( + PassedType, expectedNumOfElts, VectorKind::Generic); + auto Check = [](QualType PassedType) { + return PassedType->getAs() == nullptr; + }; + + if (CheckArgTypeIsCorrect(SemaPtr, localArg, ExpectedType, Check)) { + return true; + } + } + + if (auto *localArgVecTy = + TheCall->getArg(0)->getType()->getAs()) { + TheCall->setType(localArgVecTy->getElementType()); + } + + return false; +} diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index d748c10455289..3bfdd195ac7f1 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1996,39 +1996,12 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) { return false; } -static bool CheckArgTypeIsCorrect( - Sema *S, Expr *Arg, QualType ExpectedType, - llvm::function_ref Check) { - QualType PassedType = Arg->getType(); - if (Check(PassedType)) { - if (auto *VecTyA = PassedType->getAs()) - ExpectedType = S->Context.getVectorType( - ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); - S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) - << PassedType << ExpectedType << 1 << 0 << 0; - return true; - } - return false; -} - -static bool CheckAllArgTypesAreCorrect( - Sema *S, CallExpr *TheCall, QualType ExpectedType, - llvm::function_ref Check) { - for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { - Expr *Arg = TheCall->getArg(i); - if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { - return true; - } - } - return false; -} - static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) { auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool { return !PassedType->hasFloatingRepresentation(); }; - return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, - checkAllFloatTypes); + return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, + checkAllFloatTypes); } static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) { @@ -2039,8 +2012,8 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) { : PassedType; return !BaseType->isHalfType() && !BaseType->isFloat32Type(); }; - return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, - checkFloatorHalf); + return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, + checkFloatorHalf); } static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, @@ -2060,24 +2033,24 @@ static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) { return VecTy->getElementType()->isDoubleType(); return false; }; - return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, - checkDoubleVector); + return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, + checkDoubleVector); } static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) { auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool { return !PassedType->hasIntegerRepresentation() && !PassedType->hasFloatingRepresentation(); }; - return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy, - checkAllSignedTypes); + return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy, + checkAllSignedTypes); } static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) { auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool { return !PassedType->hasUnsignedIntegerRepresentation(); }; - return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy, - checkAllUnsignedTypes); + return clang::Sema::CheckAllArgTypesAreCorrect( + S, TheCall, S->Context.UnsignedIntTy, checkAllUnsignedTypes); } static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, diff --git a/clang/lib/Sema/SemaSPIRV.cpp b/clang/lib/Sema/SemaSPIRV.cpp index dc49fc7907357..dc091d3454e7f 100644 --- a/clang/lib/Sema/SemaSPIRV.cpp +++ b/clang/lib/Sema/SemaSPIRV.cpp @@ -20,54 +20,10 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { switch (BuiltinID) { case SPIRV::BI__builtin_spirv_distance: { - if (SemaRef.checkArgCount(TheCall, 2)) - return true; - - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTyA = ArgTyA->getAs(); - if (VTyA == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - - ExprResult B = TheCall->getArg(1); - QualType ArgTyB = B.get()->getType(); - auto *VTyB = ArgTyB->getAs(); - if (VTyB == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyB - << SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - - QualType RetTy = VTyA->getElementType(); - TheCall->setType(RetTy); - break; + return clang::Sema::CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2, 2); } case SPIRV::BI__builtin_spirv_length: { - if (SemaRef.checkArgCount(TheCall, 1)) - return true; - ExprResult A = TheCall->getArg(0); - QualType ArgTyA = A.get()->getType(); - auto *VTy = ArgTyA->getAs(); - if (VTy == nullptr) { - SemaRef.Diag(A.get()->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << ArgTyA - << SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1 - << 0 << 0; - return true; - } - QualType RetTy = VTy->getElementType(); - TheCall->setType(RetTy); - break; + return clang::Sema::CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 1, 2); } } return false;