Skip to content

[clang:frontend] Move helper functions to common location for SemaSPIRV #125045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions clang/include/clang/Sema/Common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef LLVM_CLANG_SEMA_COMMON_H
#define LLVM_CLANG_SEMA_COMMON_H

#include "clang/Sema/Sema.h"

namespace clang {

using LLVMFnRef = llvm::function_ref<bool(clang::QualType PassedType)>;
using PairParam = std::pair<unsigned int, unsigned int>;
using CheckParam = std::variant<PairParam, LLVMFnRef>;

bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check);

bool CheckAllArgTypesAreCorrect(
Sema *SemaPtr, CallExpr *TheCall,
std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check);

} // namespace clang

#endif
1 change: 1 addition & 0 deletions clang/lib/Sema/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ add_clang_library(clangSema
AnalysisBasedWarnings.cpp
CheckExprLifetime.cpp
CodeCompleteConsumer.cpp
Common.cpp
DeclSpec.cpp
DelayedDiagnostic.cpp
HeuristicResolver.cpp
Expand Down
65 changes: 65 additions & 0 deletions clang/lib/Sema/Common.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
#include "clang/Sema/Common.h"

namespace clang {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://llvm.org/docs/CodingStandards.html#use-namespace-qualifiers-to-implement-previously-declared-functions

For example clang::CheckArgTypeIsCorrect instead of opening the clang namespace and just having CheckArgTypeIsCorrect.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
<< PassedType << ExpectedType << 1 << 0 << 0;
return true;
}
return false;
}

bool CheckAllArgTypesAreCorrect(
Sema *SemaPtr, CallExpr *TheCall,
std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check) {
unsigned int NumElts;
unsigned int expected;
if (auto *n = std::get_if<PairParam>(&Check)) {
if (SemaPtr->checkArgCount(TheCall, n->first)) {
return true;
}
NumElts = n->first;
expected = n->second;
} else {
NumElts = TheCall->getNumArgs();
}

for (unsigned i = 0; i < NumElts; i++) {
Expr *localArg = TheCall->getArg(i);
if (auto *val = std::get_if<QualType>(&ExpectedType)) {
if (auto *fn = std::get_if<LLVMFnRef>(&Check)) {
return CheckArgTypeIsCorrect(SemaPtr, localArg, *val, *fn);
}
}

QualType PassedType = localArg->getType();
if (PassedType->getAs<VectorType>() == nullptr) {
SemaPtr->Diag(localArg->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< PassedType
<< SemaPtr->Context.getVectorType(PassedType, expected,
VectorKind::Generic)
<< 1 << 0 << 0;
return true;
}
}

if (std::get_if<PairParam>(&Check)) {
if (auto *localArgVecTy =
TheCall->getArg(0)->getType()->getAs<VectorType>()) {
TheCall->setType(localArgVecTy->getElementType());
}
}

return false;
}

} // namespace clang
28 changes: 1 addition & 27 deletions clang/lib/Sema/SemaHLSL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "clang/Basic/SourceLocation.h"
#include "clang/Basic/Specifiers.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/Sema/Common.h"
#include "clang/Sema/Initialization.h"
#include "clang/Sema/ParsedAttr.h"
#include "clang/Sema/Sema.h"
Expand Down Expand Up @@ -1996,33 +1997,6 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
return false;
}

static bool CheckArgTypeIsCorrect(
Sema *S, Expr *Arg, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
QualType PassedType = Arg->getType();
if (Check(PassedType)) {
if (auto *VecTyA = PassedType->getAs<VectorType>())
ExpectedType = S->Context.getVectorType(
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
<< PassedType << ExpectedType << 1 << 0 << 0;
return true;
}
return false;
}

static bool CheckAllArgTypesAreCorrect(
Sema *S, CallExpr *TheCall, QualType ExpectedType,
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
Expr *Arg = TheCall->getArg(i);
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
return true;
}
}
return false;
}

static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
return !PassedType->hasFloatingRepresentation();
Expand Down
52 changes: 6 additions & 46 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

#include "clang/Sema/SemaSPIRV.h"
#include "clang/Basic/TargetBuiltins.h"
#include "clang/Sema/Common.h"
#include "clang/Sema/Sema.h"
#include <utility>

namespace clang {

Expand All @@ -20,54 +22,12 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
CallExpr *TheCall) {
switch (BuiltinID) {
case SPIRV::BI__builtin_spirv_distance: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

QualType RetTy = VTyA->getElementType();
TheCall->setType(RetTy);
break;
return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt,
std::make_pair(2, 2));
}
case SPIRV::BI__builtin_spirv_length: {
if (SemaRef.checkArgCount(TheCall, 1))
return true;
ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTy = ArgTyA->getAs<VectorType>();
if (VTy == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}
QualType RetTy = VTy->getElementType();
TheCall->setType(RetTy);
break;
return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt,
std::make_pair(1, 2));
}
}
return false;
Expand Down