Skip to content

Commit 390a4c7

Browse files
committed
modularize SemaSPIRV.cpp
1 parent b46a32f commit 390a4c7

File tree

7 files changed

+69
-104
lines changed

7 files changed

+69
-104
lines changed

clang/lib/CodeGen/TargetBuiltins/SPIR.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
6464
Value *eta = EmitScalarExpr(E->getArg(2));
6565
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
6666
E->getArg(1)->getType()->hasFloatingRepresentation() &&
67-
E->getArg(2)->getType()->hasFloatingRepresentation() &&
67+
E->getArg(2)->getType()->isFloatingType() &&
6868
"refract operands must have a float representation");
6969
assert(E->getArg(0)->getType()->isVectorType() &&
7070
E->getArg(1)->getType()->isVectorType() &&

clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
#define _HLSL_HLSL_INTRINSIC_HELPERS_H_
1111

1212
namespace hlsl {
13-
namespace __dETAil {
13+
namespace __detail {
1414

1515
constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
1616
// Use the same scaling factor used by FXC, and DXC for DXIL
@@ -73,10 +73,8 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
7373

7474
template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
7575
T K = 1 - Eta * Eta * (1 - (N * I * N * I));
76-
if (K < 0)
77-
return 0;
78-
else
79-
return (Eta * I - (Eta * N * I + sqrt(K)) * N);
76+
T Result = (Eta * I - (Eta * N * I + sqrt(K)) * N);
77+
return select<T>(K < 0, static_cast<T>(0), Result);
8078
}
8179

8280
template <typename T, int L>
@@ -85,13 +83,12 @@ constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
8583
return __builtin_spirv_refract(I, N, Eta);
8684
#else
8785
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
88-
if (K < 0)
89-
return 0;
90-
else
91-
return (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
86+
vector<T, L> Result = (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
87+
return select<vector<T, L>>(K < 0, vector<T, L>(0), Result);
9288
#endif
9389
}
9490

91+
9592
template <typename T> constexpr T fmod_impl(T X, T Y) {
9693
#if !defined(__DIRECTX__)
9794
return __builtin_elementwise_fmod(X, Y);

clang/lib/Sema/SemaSPIRV.cpp

Lines changed: 42 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,31 @@ namespace clang {
2929

3030
SemaSPIRV::SemaSPIRV(Sema &S) : SemaBase(S) {}
3131

32+
/// Checks if the first `NumArgsToCheck` arguments of a function call are of vector type.
33+
/// If any of the arguments is not a vector type, it emits a diagnostic error and returns `true`.
34+
/// Otherwise, it returns `false`.
35+
///
36+
/// \param TheCall The function call expression to check.
37+
/// \param NumArgsToCheck The number of arguments to check for vector type.
38+
/// \return `true` if any of the arguments is not a vector type, `false` otherwise.
39+
40+
bool SemaSPIRV::CheckVectorArgs(CallExpr *TheCall, unsigned NumArgsToCheck) {
41+
for (unsigned i = 0; i < NumArgsToCheck; ++i) {
42+
ExprResult Arg = TheCall->getArg(i);
43+
QualType ArgTy = Arg.get()->getType();
44+
auto *VTy = ArgTy->getAs<VectorType>();
45+
if (VTy == nullptr) {
46+
SemaRef.Diag(Arg.get()->getBeginLoc(),
47+
diag::err_typecheck_convert_incompatible)
48+
<< ArgTy
49+
<< SemaRef.Context.getVectorType(ArgTy, 2, VectorKind::Generic) << 1
50+
<< 0 << 0;
51+
return true;
52+
}
53+
}
54+
return false;
55+
}
56+
3257
static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
3358
assert(TheCall->getNumArgs() > 1);
3459
QualType ArgTy0 = TheCall->getArg(0)->getType();
@@ -45,6 +70,7 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
4570
}
4671
return false;
4772
}
73+
bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
4874

4975
static std::optional<int>
5076
processConstant32BitIntArgument(Sema &SemaRef, CallExpr *Call, int Argument) {
@@ -157,122 +183,56 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
157183
if (SemaRef.checkArgCount(TheCall, 2))
158184
return true;
159185

160-
ExprResult A = TheCall->getArg(0);
161-
QualType ArgTyA = A.get()->getType();
162-
auto *VTyA = ArgTyA->getAs<VectorType>();
163-
if (VTyA == nullptr) {
164-
SemaRef.Diag(A.get()->getBeginLoc(),
165-
diag::err_typecheck_convert_incompatible)
166-
<< ArgTyA
167-
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
168-
<< 0 << 0;
169-
return true;
170-
}
171-
172-
ExprResult B = TheCall->getArg(1);
173-
QualType ArgTyB = B.get()->getType();
174-
auto *VTyB = ArgTyB->getAs<VectorType>();
175-
if (VTyB == nullptr) {
176-
SemaRef.Diag(B.get()->getBeginLoc(),
177-
diag::err_typecheck_convert_incompatible)
178-
<< ArgTyB
179-
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
180-
<< 0 << 0;
186+
// Use the helper function to check both arguments
187+
if (CheckVectorArgs(TheCall, 2))
181188
return true;
182-
}
183189

184-
QualType RetTy = VTyA->getElementType();
190+
QualType RetTy = TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
185191
TheCall->setType(RetTy);
186192
break;
187193
}
188194
case SPIRV::BI__builtin_spirv_length: {
189195
if (SemaRef.checkArgCount(TheCall, 1))
190196
return true;
191-
ExprResult A = TheCall->getArg(0);
192-
QualType ArgTyA = A.get()->getType();
193-
auto *VTy = ArgTyA->getAs<VectorType>();
194-
if (VTy == nullptr) {
195-
SemaRef.Diag(A.get()->getBeginLoc(),
196-
diag::err_typecheck_convert_incompatible)
197-
<< ArgTyA
198-
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
199-
<< 0 << 0;
197+
198+
// Use the helper function to check the argument
199+
if (CheckVectorArgs(TheCall, 1))
200200
return true;
201-
}
202-
QualType RetTy = VTy->getElementType();
201+
202+
QualType RetTy = TheCall->getArg(0)->getType()->getAs<VectorType>()->getElementType();
203203
TheCall->setType(RetTy);
204204
break;
205205
}
206206
case SPIRV::BI__builtin_spirv_refract: {
207207
if (SemaRef.checkArgCount(TheCall, 3))
208208
return true;
209209

210-
ExprResult A = TheCall->getArg(0);
211-
QualType ArgTyA = A.get()->getType();
212-
auto *VTyA = ArgTyA->getAs<VectorType>();
213-
if (VTyA == nullptr) {
214-
SemaRef.Diag(A.get()->getBeginLoc(),
215-
diag::err_typecheck_convert_incompatible)
216-
<< ArgTyA
217-
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
218-
<< 0 << 0;
210+
// Use the helper function to check the first two arguments
211+
if (CheckVectorArgs(TheCall, 2))
219212
return true;
220-
}
221-
222-
ExprResult B = TheCall->getArg(1);
223-
QualType ArgTyB = B.get()->getType();
224-
auto *VTyB = ArgTyB->getAs<VectorType>();
225-
if (VTyB == nullptr) {
226-
SemaRef.Diag(B.get()->getBeginLoc(),
227-
diag::err_typecheck_convert_incompatible)
228-
<< ArgTyB
229-
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
230-
<< 0 << 0;
231-
return true;
232-
}
233213

234214
ExprResult C = TheCall->getArg(2);
235215
QualType ArgTyC = C.get()->getType();
236-
if (!ArgTyC->hasFloatingRepresentation()) {
216+
if (!ArgTyC->isFloatingType()) {
237217
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
238-
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
218+
<< 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1
239219
<< ArgTyC;
240220
return true;
241221
}
242222

243-
QualType RetTy = ArgTyA;
223+
QualType RetTy = TheCall->getArg(0)->getType();
244224
TheCall->setType(RetTy);
245225
break;
246226
}
247227
case SPIRV::BI__builtin_spirv_reflect: {
248228
if (SemaRef.checkArgCount(TheCall, 2))
249229
return true;
250230

251-
ExprResult A = TheCall->getArg(0);
252-
QualType ArgTyA = A.get()->getType();
253-
auto *VTyA = ArgTyA->getAs<VectorType>();
254-
if (VTyA == nullptr) {
255-
SemaRef.Diag(A.get()->getBeginLoc(),
256-
diag::err_typecheck_convert_incompatible)
257-
<< ArgTyA
258-
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
259-
<< 0 << 0;
260-
return true;
261-
}
262-
263-
ExprResult B = TheCall->getArg(1);
264-
QualType ArgTyB = B.get()->getType();
265-
auto *VTyB = ArgTyB->getAs<VectorType>();
266-
if (VTyB == nullptr) {
267-
SemaRef.Diag(B.get()->getBeginLoc(),
268-
diag::err_typecheck_convert_incompatible)
269-
<< ArgTyB
270-
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
271-
<< 0 << 0;
231+
// Use the helper function to check both arguments
232+
if (CheckVectorArgs(TheCall, 2))
272233
return true;
273-
}
274234

275-
QualType RetTy = ArgTyA;
235+
QualType RetTy = TheCall->getArg(0)->getType();
276236
TheCall->setType(RetTy);
277237
break;
278238
}

clang/test/CodeGenHLSL/builtins/reflect.hlsl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
21
// RUN: %clang_cc1 -finclude-default-header -triple \
32
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
43
// RUN: -emit-llvm -O1 -o - | FileCheck %s

clang/test/CodeGenHLSL/builtins/refract.hlsl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
21
// RUN: %clang_cc1 -finclude-default-header -triple \
32
// RUN: dxil-pc-shadermodel6.3-library %s -fnative-half-type \
43
// RUN: -emit-llvm -O1 -o - | FileCheck %s
@@ -59,11 +58,10 @@ half test_refract_half(half I, half N, half ETA) {
5958
return refract(I, N, ETA);
6059
}
6160

62-
// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_
63-
// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
61+
// CHECK-LABEL: define noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_Dh(
62+
// CHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
6463
// CHECK-NEXT: [[ENTRY:.*:]]
65-
// CHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
66-
// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[CAST_VTRUNC]]
64+
// CHECK-NEXT: [[MUL_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]], [[ETA]]
6765
// CHECK-NEXT: [[HLSL_DOT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn half @llvm.dx.fdot.v2f16(<2 x half> [[N]], <2 x half> [[I]])
6866
// CHECK-NEXT: [[MUL_2_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[HLSL_DOT_I]]
6967
// CHECK-NEXT: [[SUB_I:%.*]] = fsub reassoc nnan ninf nsz arcp afn half 0xH3C00, [[MUL_2_I]]
@@ -76,7 +74,7 @@ half test_refract_half(half I, half N, half ETA) {
7674
// CHECK-NEXT: [[SPLAT_SPLAT_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT_I]], <2 x half> poison, <2 x i32> zeroinitializer
7775
// CHECK-NEXT: [[SPLAT_SPLAT6_I:%.*]] = shufflevector <2 x half> [[ETA]], <2 x half> poison, <2 x i32> zeroinitializer
7876
// CHECK-NEXT: [[MUL_7_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn <2 x half> [[SPLAT_SPLAT6_I]], [[I]]
79-
// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[CAST_VTRUNC]]
77+
// CHECK-NEXT: [[MUL_9_I:%.*]] = fmul reassoc nnan ninf nsz arcp afn half [[HLSL_DOT_I]], [[ETA]]
8078
// CHECK-NEXT: [[SPLAT_SPLATINSERT10_I:%.*]] = insertelement <2 x half> poison, half [[MUL_9_I]], i64 0
8179
// CHECK-NEXT: [[SPLAT_SPLAT11_I:%.*]] = shufflevector <2 x half> [[SPLAT_SPLATINSERT10_I]], <2 x half> poison, <2 x i32> zeroinitializer
8280
// CHECK-NEXT: [[TMP0:%.*]] = tail call reassoc nnan ninf nsz arcp afn <2 x half> @llvm.sqrt.v2f16(<2 x half> [[SPLAT_SPLAT_I]])
@@ -90,14 +88,13 @@ half test_refract_half(half I, half N, half ETA) {
9088

9189
//
9290
// SPVCHECK-LABEL: define spir_func noundef nofpclass(nan inf) <2 x half> @_Z18test_refract_half2Dv2_DhS_S_(
93-
// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], <2 x half> noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
91+
// SPVCHECK-SAME: <2 x half> noundef nofpclass(nan inf) [[I:%.*]], <2 x half> noundef nofpclass(nan inf) [[N:%.*]], half noundef nofpclass(nan inf) [[ETA:%.*]]) local_unnamed_addr #[[ATTR0:[0-9]+]] {
9492
// SPVCHECK-NEXT: [[ENTRY:.*:]]
95-
// SPVCHECK-NEXT: [[CAST_VTRUNC:%.*]] = extractelement <2 x half> [[ETA]], i64 0
96-
// SPVCHECK-NEXT: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[CAST_VTRUNC]] to double
93+
// SPVCHECK-NEXT: [[CONV_I:%.*]] = fpext reassoc nnan ninf nsz arcp afn half [[ETA]] to double
9794
// SPVCHECK-NEXT: [[SPV_REFRACT_I:%.*]] = tail call reassoc nnan ninf nsz arcp afn noundef <2 x half> @llvm.spv.refract.v2f16.f64(<2 x half> [[I]], <2 x half> [[N]], double [[CONV_I]])
9895
// SPVCHECK-NEXT: ret <2 x half> [[SPV_REFRACT_I]]
9996
//
100-
half2 test_refract_half2(half2 I, half2 N, half2 ETA) {
97+
half2 test_refract_half2(half2 I, half2 N, half ETA) {
10198
return refract(I, N, ETA);
10299
}
103100

clang/test/CodeGenSPIRV/Builtins/refract.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 5
22

3-
// RUN: %clang_cc1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
3+
// RUN: %clang_cc1 -O1 -triple spirv-pc-vulkan-compute %s -emit-llvm -o - | FileCheck %s
44

55
typedef float float2 __attribute__((ext_vector_type(2)));
66
typedef float float3 __attribute__((ext_vector_type(3)));
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
; RUN: not llc -verify-machineinstrs -O0 -mtriple=spirv64-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
2+
; RUN: not llc -verify-machineinstrs -O0 -mtriple=spirv32-unknown-unknown %s -o /dev/null 2>&1 | FileCheck %s
3+
4+
; CHECK: LLVM ERROR: %{{.*}} = G_INTRINSIC intrinsic(@llvm.spv.refract), %{{.*}}, %{{.*}}, %{{.*}} is only supported with the GLSL extended instruction set.
5+
6+
define noundef <4 x float> @refract_float4(<4 x float> noundef %I, <4 x float> noundef %N, float noundef %ETA) {
7+
entry:
8+
%spv.refract = call <4 x float> @llvm.spv.refract.f32(<4 x float> %I, <4 x float> %N, float %ETA)
9+
ret <4 x float> %spv.refract
10+
}
11+
12+
declare <4 x float> @llvm.spv.refract.f32(<4 x float>, <4 x float>, float)

0 commit comments

Comments
 (0)