Skip to content

Commit 199bda9

Browse files
committed
update SPV to customeTypeChecking
1 parent 2a7720c commit 199bda9

File tree

14 files changed

+390
-309
lines changed

14 files changed

+390
-309
lines changed

clang/include/clang/Basic/BuiltinsSPIRV.td

Lines changed: 0 additions & 39 deletions
This file was deleted.

clang/include/clang/Basic/BuiltinsSPIRVVK.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,4 @@ include "clang/Basic/BuiltinsSPIRVBase.td"
1111

1212
def reflect : SPIRVBuiltin<"void(...)", [NoThrow, Const]>;
1313
def faceforward : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;
14+
def refract : SPIRVBuiltin<"void(...)", [NoThrow, Const, CustomTypeChecking]>;

clang/include/clang/Sema/Sema.h

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2791,6 +2791,30 @@ class Sema final : public SemaBase {
27912791

27922792
void CheckConstrainedAuto(const AutoType *AutoT, SourceLocation Loc);
27932793

2794+
/// CheckVectorArgs - Check that the arguments of a vector function call
2795+
bool CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck);
2796+
2797+
bool CheckVectorArgs(CallExpr *TheCall);
2798+
2799+
bool CheckAllArgTypesAreCorrect(
2800+
Sema *S, CallExpr *TheCall,
2801+
llvm::ArrayRef<
2802+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
2803+
Checks);
2804+
bool CheckAllArgTypesAreCorrect(
2805+
Sema *S, CallExpr *TheCall,
2806+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check);
2807+
2808+
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
2809+
int ArgOrdinal,
2810+
clang::QualType PassedType);
2811+
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
2812+
int ArgOrdinal,
2813+
clang::QualType PassedType);
2814+
2815+
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
2816+
int ArgOrdinal,
2817+
clang::QualType PassedType);
27942818
/// BuiltinConstantArg - Handle a check if argument ArgNum of CallExpr
27952819
/// TheCall is a constant expression.
27962820
bool BuiltinConstantArg(CallExpr *TheCall, int ArgNum, llvm::APSInt &Result);

clang/lib/Headers/hlsl/hlsl_detail.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ template <typename T> struct is_arithmetic {
4545
static const bool Value = __is_arithmetic(T);
4646
};
4747

48+
template <typename T> struct is_vector {
49+
static const bool value = false;
50+
};
51+
52+
template <typename T, int N> struct is_vector<vector<T, N>> {
53+
static const bool value = true;
54+
};
55+
4856
template <typename T, int N>
4957
using HLSL_FIXED_VECTOR =
5058
vector<__detail::enable_if_t<(N > 1 && N <= 4), T>, N>;

clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,22 +72,41 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
7272
}
7373

7474
template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
75-
T K = 1 - Eta * Eta * (1 - (N * I * N * I));
76-
T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N);
75+
T Mul = N * I;
76+
T K = 1 - Eta * Eta * (1 - (Mul * Mul));
77+
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
7778
return select<T>(K < 0, static_cast<T>(0), Result);
7879
}
7980

81+
template <typename T, typename U>
82+
constexpr T refract_vec_impl(T I, T N, U Eta) {
83+
#if (__has_builtin(__builtin_spirv_refract))
84+
if (is_vector<T>::value) {
85+
return __builtin_spirv_refract(I, N, Eta);
86+
}
87+
#else
88+
T Mul = dot(N, I);
89+
T K = 1 - Eta * Eta * (1 - Mul * Mul);
90+
T Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
91+
return select<T>(K < 0, static_cast<T>(0), Result);
92+
#endif
93+
}
94+
95+
/*
8096
template <typename T, int L>
8197
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
82-
#if (__has_builtin(__builtin_spirv_refract))
98+
#if (__has_builtin(__builtin_spirv_refract) && is_vector<T>))
8399
return __builtin_spirv_refract(I, N, Eta);
84100
#else
85-
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
86-
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
101+
T Mul = dot(N, I);
102+
vector<T, L> K = 1 - Eta * Eta * (1 - Mul * Mul);
103+
vector<T, L> Result = (Eta * I - (Eta * Mul + sqrt(K)) * N);
87104
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
88105
#endif
89106
}
90107
108+
*/
109+
91110
template <typename T> constexpr T fmod_impl(T X, T Y) {
92111
#if !defined(__DIRECTX__)
93112
return __builtin_elementwise_fmod(X, Y);

clang/lib/Sema/SemaChecking.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16151,3 +16151,108 @@ void Sema::CheckTCBEnforcement(const SourceLocation CallExprLoc,
1615116151
}
1615216152
}
1615316153
}
16154+
16155+
bool Sema::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
16156+
for (unsigned i = 0; i < NumArgsToCheck; ++i) {
16157+
ExprResult Arg = TheCall->getArg(i);
16158+
QualType ArgTy = Arg.get()->getType();
16159+
auto *VTy = ArgTy->getAs<VectorType>();
16160+
if (VTy == nullptr) {
16161+
SemaRef.Diag(Arg.get()->getBeginLoc(),
16162+
diag::err_typecheck_convert_incompatible)
16163+
<< ArgTy
16164+
<< SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
16165+
<< 0 << 0;
16166+
return true;
16167+
}
16168+
}
16169+
return false;
16170+
}
16171+
16172+
bool Sema::CheckVectorArgs(CallExpr *TheCall) {
16173+
return CheckVectorArgs(TheCall, TheCall->getNumArgs());
16174+
}
16175+
16176+
16177+
bool Sema::CheckAllArgTypesAreCorrect(
16178+
Sema *S, CallExpr *TheCall,
16179+
llvm::ArrayRef<
16180+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
16181+
Checks) {
16182+
unsigned NumArgs = TheCall->getNumArgs();
16183+
if (Checks.size() == 1) {
16184+
// Apply the single check to all arguments
16185+
for (unsigned I = 0; I < NumArgs; ++I) {
16186+
Expr *Arg = TheCall->getArg(I);
16187+
if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
16188+
return true;
16189+
}
16190+
return false;
16191+
} else if (Checks.size() == NumArgs) {
16192+
// Apply each check to the corresponding argument
16193+
for (unsigned I = 0; I < NumArgs; ++I) {
16194+
Expr *Arg = TheCall->getArg(I);
16195+
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
16196+
return true;
16197+
}
16198+
return false;
16199+
} else {
16200+
// Mismatch: error or fallback
16201+
S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
16202+
<< NumArgs << Checks.size();
16203+
return true;
16204+
}
16205+
}
16206+
16207+
bool Sema::CheckAllArgTypesAreCorrect(
16208+
Sema *S, CallExpr *TheCall,
16209+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
16210+
return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
16211+
}
16212+
16213+
bool Sema::CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
16214+
int ArgOrdinal,
16215+
clang::QualType PassedType) {
16216+
clang::QualType BaseType =
16217+
PassedType->isVectorType()
16218+
? PassedType->castAs<clang::VectorType>()->getElementType()
16219+
: PassedType;
16220+
if (!BaseType->isHalfType() && !BaseType->isFloat32Type())
16221+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16222+
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
16223+
<< /* half or float */ 2 << PassedType;
16224+
return false;
16225+
}
16226+
16227+
bool Sema::CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
16228+
int ArgOrdinal,
16229+
clang::QualType PassedType) {
16230+
const auto *VecTy = PassedType->getAs<VectorType>();
16231+
16232+
clang::QualType BaseType =
16233+
PassedType->isVectorType()
16234+
? PassedType->castAs<clang::VectorType>()->getElementType()
16235+
: PassedType;
16236+
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
16237+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16238+
<< ArgOrdinal << /* vector of */ 5 << /* no int */ 0
16239+
<< /* half or float */ 2 << PassedType;
16240+
return false;
16241+
}
16242+
16243+
bool Sema::CheckFloatOrHalfScalarRepresentation(
16244+
Sema *S, SourceLocation Loc,
16245+
int ArgOrdinal,
16246+
clang::QualType PassedType) {
16247+
const auto *VecTy = PassedType->getAs<VectorType>();
16248+
16249+
clang::QualType BaseType =
16250+
PassedType->isVectorType()
16251+
? PassedType->castAs<clang::VectorType>()->getElementType()
16252+
: PassedType;
16253+
if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
16254+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
16255+
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
16256+
<< /* half or float */ 2 << PassedType;
16257+
return false;
16258+
}

clang/lib/Sema/SemaHLSL.cpp

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2401,17 +2401,40 @@ static bool CheckArgTypeMatches(Sema *S, Expr *Arg, QualType ExpectedType) {
24012401
return false;
24022402
}
24032403

2404-
static bool CheckAllArgTypesAreCorrect(
2404+
bool CheckAllArgTypesAreCorrect(
24052405
Sema *S, CallExpr *TheCall,
2406-
llvm::function_ref<bool(Sema *S, SourceLocation Loc, int ArgOrdinal,
2407-
clang::QualType PassedType)>
2408-
Check) {
2409-
for (unsigned I = 0; I < TheCall->getNumArgs(); ++I) {
2410-
Expr *Arg = TheCall->getArg(I);
2411-
if (Check(S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2412-
return true;
2406+
llvm::ArrayRef<
2407+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)>>
2408+
Checks) {
2409+
unsigned NumArgs = TheCall->getNumArgs();
2410+
if (Checks.size() == 1) {
2411+
// Apply the single check to all arguments
2412+
for (unsigned I = 0; I < NumArgs; ++I) {
2413+
Expr *Arg = TheCall->getArg(I);
2414+
if (Checks[0](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2415+
return true;
2416+
}
2417+
return false;
2418+
} else if (Checks.size() == NumArgs) {
2419+
// Apply each check to the corresponding argument
2420+
for (unsigned I = 0; I < NumArgs; ++I) {
2421+
Expr *Arg = TheCall->getArg(I);
2422+
if (Checks[I](S, Arg->getBeginLoc(), I + 1, Arg->getType()))
2423+
return true;
2424+
}
2425+
return false;
2426+
} else {
2427+
// Mismatch: error or fallback
2428+
S->Diag(TheCall->getBeginLoc(), diag::err_builtin_invalid_arg_type)
2429+
<< NumArgs << Checks.size();
2430+
return true;
24132431
}
2414-
return false;
2432+
}
2433+
2434+
bool CheckAllArgTypesAreCorrect(
2435+
Sema *S, CallExpr *TheCall,
2436+
llvm::function_ref<bool(Sema *, SourceLocation, int, QualType)> Check) {
2437+
return CheckAllArgTypesAreCorrect(S, TheCall, llvm::ArrayRef{Check});
24152438
}
24162439

24172440
static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
@@ -2428,6 +2451,38 @@ static bool CheckFloatOrHalfRepresentation(Sema *S, SourceLocation Loc,
24282451
return false;
24292452
}
24302453

2454+
static bool CheckFloatOrHalfVectorsRepresentation(Sema *S, SourceLocation Loc,
2455+
int ArgOrdinal,
2456+
clang::QualType PassedType) {
2457+
const auto *VecTy = PassedType->getAs<VectorType>();
2458+
2459+
clang::QualType BaseType =
2460+
PassedType->isVectorType()
2461+
? PassedType->castAs<clang::VectorType>()->getElementType()
2462+
: PassedType;
2463+
if (!VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
2464+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2465+
<< ArgOrdinal << /* vector of */ 5 << /* no int */ 0
2466+
<< /* half or float */ 2 << PassedType;
2467+
return false;
2468+
}
2469+
2470+
static bool CheckFloatOrHalfScalarRepresentation(Sema *S, SourceLocation Loc,
2471+
int ArgOrdinal,
2472+
clang::QualType PassedType) {
2473+
const auto *VecTy = PassedType->getAs<VectorType>();
2474+
2475+
clang::QualType BaseType =
2476+
PassedType->isVectorType()
2477+
? PassedType->castAs<clang::VectorType>()->getElementType()
2478+
: PassedType;
2479+
if (VecTy || !BaseType->isHalfType() && !BaseType->isFloat32Type())
2480+
return S->Diag(Loc, diag::err_builtin_invalid_arg_type)
2481+
<< ArgOrdinal << /* scalar or vector of */ 5 << /* no int */ 0
2482+
<< /* half or float */ 2 << PassedType;
2483+
return false;
2484+
}
2485+
24312486
static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall,
24322487
unsigned ArgIndex) {
24332488
auto *Arg = TheCall->getArg(ArgIndex);

0 commit comments

Comments
 (0)