Skip to content

Commit d938f75

Browse files
committed
[clang:frontend] Move helper functions in SemaHLSL to common location for SemaSPIRV
1 parent 62f6d63 commit d938f75

File tree

5 files changed

+95
-73
lines changed

5 files changed

+95
-73
lines changed

clang/include/clang/Sema/Common.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef LLVM_CLANG_SEMA_COMMON_H
2+
#define LLVM_CLANG_SEMA_COMMON_H
3+
4+
#include "clang/Sema/Sema.h"
5+
6+
namespace clang {
7+
8+
using LLVMFnRef = llvm::function_ref<bool(clang::QualType PassedType)>;
9+
using PairParam = std::pair<unsigned int, unsigned int>;
10+
using CheckParam = std::variant<PairParam, LLVMFnRef>;
11+
12+
bool CheckArgTypeIsCorrect(
13+
Sema *S, Expr *Arg, QualType ExpectedType,
14+
llvm::function_ref<bool(clang::QualType PassedType)> Check);
15+
16+
bool CheckAllArgTypesAreCorrect(
17+
Sema *SemaPtr, CallExpr *TheCall,
18+
std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check);
19+
20+
} // namespace clang
21+
22+
#endif

clang/lib/Sema/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ add_clang_library(clangSema
1717
AnalysisBasedWarnings.cpp
1818
CheckExprLifetime.cpp
1919
CodeCompleteConsumer.cpp
20+
Common.cpp
2021
DeclSpec.cpp
2122
DelayedDiagnostic.cpp
2223
HeuristicResolver.cpp

clang/lib/Sema/Common.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#include "clang/Sema/Common.h"
2+
3+
namespace clang {
4+
5+
bool CheckArgTypeIsCorrect(
6+
Sema *S, Expr *Arg, QualType ExpectedType,
7+
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
8+
QualType PassedType = Arg->getType();
9+
if (Check(PassedType)) {
10+
if (auto *VecTyA = PassedType->getAs<VectorType>())
11+
ExpectedType = S->Context.getVectorType(
12+
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
13+
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
14+
<< PassedType << ExpectedType << 1 << 0 << 0;
15+
return true;
16+
}
17+
return false;
18+
}
19+
20+
bool CheckAllArgTypesAreCorrect(
21+
Sema *SemaPtr, CallExpr *TheCall,
22+
std::variant<QualType, std::nullopt_t> ExpectedType, CheckParam Check) {
23+
unsigned int NumElts;
24+
unsigned int expected;
25+
if (auto *n = std::get_if<PairParam>(&Check)) {
26+
if (SemaPtr->checkArgCount(TheCall, n->first)) {
27+
return true;
28+
}
29+
NumElts = n->first;
30+
expected = n->second;
31+
} else {
32+
NumElts = TheCall->getNumArgs();
33+
}
34+
35+
for (unsigned i = 0; i < NumElts; i++) {
36+
Expr *localArg = TheCall->getArg(i);
37+
if (auto *val = std::get_if<QualType>(&ExpectedType)) {
38+
if (auto *fn = std::get_if<LLVMFnRef>(&Check)) {
39+
return CheckArgTypeIsCorrect(SemaPtr, localArg, *val, *fn);
40+
}
41+
}
42+
43+
QualType PassedType = localArg->getType();
44+
if (PassedType->getAs<VectorType>() == nullptr) {
45+
SemaPtr->Diag(localArg->getBeginLoc(),
46+
diag::err_typecheck_convert_incompatible)
47+
<< PassedType
48+
<< SemaPtr->Context.getVectorType(PassedType, expected,
49+
VectorKind::Generic)
50+
<< 1 << 0 << 0;
51+
return true;
52+
}
53+
}
54+
55+
if (std::get_if<PairParam>(&Check)) {
56+
if (auto *localArgVecTy =
57+
TheCall->getArg(0)->getType()->getAs<VectorType>()) {
58+
TheCall->setType(localArgVecTy->getElementType());
59+
}
60+
}
61+
62+
return false;
63+
}
64+
65+
} // namespace clang

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "clang/Basic/SourceLocation.h"
2828
#include "clang/Basic/Specifiers.h"
2929
#include "clang/Basic/TargetInfo.h"
30+
#include "clang/Sema/Common.h"
3031
#include "clang/Sema/Initialization.h"
3132
#include "clang/Sema/ParsedAttr.h"
3233
#include "clang/Sema/Sema.h"
@@ -1996,33 +1997,6 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
19961997
return false;
19971998
}
19981999

1999-
static bool CheckArgTypeIsCorrect(
2000-
Sema *S, Expr *Arg, QualType ExpectedType,
2001-
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
2002-
QualType PassedType = Arg->getType();
2003-
if (Check(PassedType)) {
2004-
if (auto *VecTyA = PassedType->getAs<VectorType>())
2005-
ExpectedType = S->Context.getVectorType(
2006-
ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind());
2007-
S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible)
2008-
<< PassedType << ExpectedType << 1 << 0 << 0;
2009-
return true;
2010-
}
2011-
return false;
2012-
}
2013-
2014-
static bool CheckAllArgTypesAreCorrect(
2015-
Sema *S, CallExpr *TheCall, QualType ExpectedType,
2016-
llvm::function_ref<bool(clang::QualType PassedType)> Check) {
2017-
for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) {
2018-
Expr *Arg = TheCall->getArg(i);
2019-
if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) {
2020-
return true;
2021-
}
2022-
}
2023-
return false;
2024-
}
2025-
20262000
static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) {
20272001
auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool {
20282002
return !PassedType->hasFloatingRepresentation();

clang/lib/Sema/SemaSPIRV.cpp

Lines changed: 6 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010

1111
#include "clang/Sema/SemaSPIRV.h"
1212
#include "clang/Basic/TargetBuiltins.h"
13+
#include "clang/Sema/Common.h"
1314
#include "clang/Sema/Sema.h"
15+
#include <utility>
1416

1517
namespace clang {
1618

@@ -20,54 +22,12 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
2022
CallExpr *TheCall) {
2123
switch (BuiltinID) {
2224
case SPIRV::BI__builtin_spirv_distance: {
23-
if (SemaRef.checkArgCount(TheCall, 2))
24-
return true;
25-
26-
ExprResult A = TheCall->getArg(0);
27-
QualType ArgTyA = A.get()->getType();
28-
auto *VTyA = ArgTyA->getAs<VectorType>();
29-
if (VTyA == nullptr) {
30-
SemaRef.Diag(A.get()->getBeginLoc(),
31-
diag::err_typecheck_convert_incompatible)
32-
<< ArgTyA
33-
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
34-
<< 0 << 0;
35-
return true;
36-
}
37-
38-
ExprResult B = TheCall->getArg(1);
39-
QualType ArgTyB = B.get()->getType();
40-
auto *VTyB = ArgTyB->getAs<VectorType>();
41-
if (VTyB == nullptr) {
42-
SemaRef.Diag(A.get()->getBeginLoc(),
43-
diag::err_typecheck_convert_incompatible)
44-
<< ArgTyB
45-
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
46-
<< 0 << 0;
47-
return true;
48-
}
49-
50-
QualType RetTy = VTyA->getElementType();
51-
TheCall->setType(RetTy);
52-
break;
25+
return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt,
26+
std::make_pair(2, 2));
5327
}
5428
case SPIRV::BI__builtin_spirv_length: {
55-
if (SemaRef.checkArgCount(TheCall, 1))
56-
return true;
57-
ExprResult A = TheCall->getArg(0);
58-
QualType ArgTyA = A.get()->getType();
59-
auto *VTy = ArgTyA->getAs<VectorType>();
60-
if (VTy == nullptr) {
61-
SemaRef.Diag(A.get()->getBeginLoc(),
62-
diag::err_typecheck_convert_incompatible)
63-
<< ArgTyA
64-
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
65-
<< 0 << 0;
66-
return true;
67-
}
68-
QualType RetTy = VTy->getElementType();
69-
TheCall->setType(RetTy);
70-
break;
29+
return CheckAllArgTypesAreCorrect(&SemaRef, TheCall, std::nullopt,
30+
std::make_pair(1, 2));
7131
}
7232
}
7333
return false;

0 commit comments

Comments
 (0)