Skip to content

[clang:frontend] Move helper functions to common location for SemaSPIRV #125045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions clang/include/clang/Sema/Sema.h
Original file line number Diff line number Diff line change
Expand Up @@ -15216,6 +15216,30 @@ class Sema final : public SemaBase {
void performFunctionEffectAnalysis(TranslationUnitDecl *TU);

///@}

//
//
// -------------------------------------------------------------------------
//
//

/// \name Common Helper Functions
/// Implementations are in Common.cpp
///@{
public:
static bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check);

static bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check);

static bool CheckAllArgTypesAreCorrect(Sema *SemaPtr, CallExpr *TheCall,
unsigned int NumOfElts,
unsigned int expectedNumOfElts);

///@}
};

DeductionFailureInfo
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_clang_library(clangSema
AnalysisBasedWarnings.cpp
CheckExprLifetime.cpp
CodeCompleteConsumer.cpp
Common.cpp
DeclSpec.cpp
DelayedDiagnostic.cpp
HeuristicResolver.cpp
Expand Down
68 changes: 68 additions & 0 deletions clang/lib/Sema/Common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
//===--- Common.cpp --- Semantic Analysis common implementation file ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
// This implements common functions used in SPIRV and HLSL semantic
// analysis constructs.
//===----------------------------------------------------------------------===//

#include "clang/Sema/Sema.h"

bool clang::Sema::CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
<< PassedType << ExpectedType << 1 << 0 << 0;
return true;
}
return false;
}

bool clang::Sema::CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
Expr *Arg = TheCall->getArg(i);
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
return false;
}

bool clang::Sema::CheckAllArgTypesAreCorrect(Sema *SemaPtr, CallExpr *TheCall,
unsigned int NumOfElts,
unsigned int expectedNumOfElts) {
if (SemaPtr->checkArgCount(TheCall, NumOfElts)) {
return true;
}

for (unsigned i = 0; i < NumOfElts; i++) {
Expr *localArg = TheCall->getArg(i);
QualType PassedType = localArg->getType();
QualType ExpectedType = SemaPtr->Context.getVectorType(
PassedType, expectedNumOfElts, VectorKind::Generic);
auto Check = [](QualType PassedType) {
return PassedType->getAs<VectorType>() == nullptr;
};

if (CheckArgTypeIsCorrect(SemaPtr, localArg, ExpectedType, Check)) {
return true;
}
}

if (auto *localArgVecTy =
TheCall->getArg(0)->getType()->getAs<VectorType>()) {
TheCall->setType(localArgVecTy->getElementType());
}

return false;
}
47 changes: 10 additions & 37 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1996,39 +1996,12 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
return false;
}

static bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
<< PassedType << ExpectedType << 1 << 0 << 0;
return true;
}
return false;
}

static bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
Expr *Arg = TheCall->getArg(i);
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
return false;
}

static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
};
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkAllFloatTypes);
return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkAllFloatTypes);
}

static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
Expand All @@ -2039,8 +2012,8 @@ static bool CheckFloatOrHalfRepresentations(Sema *S, CallExpr *TheCall) {
: PassedType;
return !BaseType->isHalfType() && !BaseType->isFloat32Type();
};
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkFloatorHalf);
return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkFloatorHalf);
}

static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
Expand All @@ -2060,24 +2033,24 @@ static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) {
return VecTy->getElementType()->isDoubleType();
return false;
};
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkDoubleVector);
return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy,
checkDoubleVector);
}
static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasIntegerRepresentation() &&
!PassedType->hasFloatingRepresentation();
};
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
checkAllSignedTypes);
return clang::Sema::CheckAllArgTypesAreCorrect(S, TheCall, S->Context.IntTy,
checkAllSignedTypes);
}

static bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllUnsignedTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasUnsignedIntegerRepresentation();
};
return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.UnsignedIntTy,
checkAllUnsignedTypes);
return clang::Sema::CheckAllArgTypesAreCorrect(
S, TheCall, S->Context.UnsignedIntTy, checkAllUnsignedTypes);
}

static void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
Expand Down
48 changes: 2 additions & 46 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,54 +20,10 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
switch (BuiltinID) {
case SPIRV::BI__builtin_spirv_distance: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

QualType RetTy = VTyA->getElementType();
TheCall->setType(RetTy);
break;
return clang::Sema::CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 2, 2);
}
case SPIRV::BI__builtin_spirv_length: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTy = ArgTyA->getAs<VectorType>();
if (VTy == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
QualType RetTy = VTy->getElementType();
TheCall->setType(RetTy);
break;
return clang::Sema::CheckAllArgTypesAreCorrect(&SemaRef, TheCall, 1, 2);
}
}
return false;
Expand Down