@@ -29,6 +29,31 @@ namespace clang {
29
29
30
30
SemaSPIRV::SemaSPIRV (Sema &S) : SemaBase(S) {}
31
31
32
+ // / Checks if the first `NumArgsToCheck` arguments of a function call are of vector type.
33
+ // / If any of the arguments is not a vector type, it emits a diagnostic error and returns `true`.
34
+ // / Otherwise, it returns `false`.
35
+ // /
36
+ // / \param TheCall The function call expression to check.
37
+ // / \param NumArgsToCheck The number of arguments to check for vector type.
38
+ // / \return `true` if any of the arguments is not a vector type, `false` otherwise.
39
+
40
+ bool SemaSPIRV::CheckVectorArgs (CallExpr *TheCall, unsigned NumArgsToCheck) {
41
+ for (unsigned i = 0 ; i < NumArgsToCheck; ++i) {
42
+ ExprResult Arg = TheCall->getArg (i);
43
+ QualType ArgTy = Arg.get ()->getType ();
44
+ auto *VTy = ArgTy->getAs <VectorType>();
45
+ if (VTy == nullptr ) {
46
+ SemaRef.Diag (Arg.get ()->getBeginLoc (),
47
+ diag::err_typecheck_convert_incompatible)
48
+ << ArgTy
49
+ << SemaRef.Context .getVectorType (ArgTy, 2 , VectorKind::Generic) << 1
50
+ << 0 << 0 ;
51
+ return true ;
52
+ }
53
+ }
54
+ return false ;
55
+ }
56
+
32
57
static bool CheckAllArgsHaveSameType (Sema *S, CallExpr *TheCall) {
33
58
assert (TheCall->getNumArgs () > 1 );
34
59
QualType ArgTy0 = TheCall->getArg (0 )->getType ();
@@ -45,6 +70,7 @@ static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) {
45
70
}
46
71
return false ;
47
72
}
73
+ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall (unsigned BuiltinID,
48
74
49
75
static std::optional<int >
50
76
processConstant32BitIntArgument (Sema &SemaRef, CallExpr *Call, int Argument) {
@@ -157,122 +183,56 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(const TargetInfo &TI,
157
183
if (SemaRef.checkArgCount (TheCall, 2 ))
158
184
return true ;
159
185
160
- ExprResult A = TheCall->getArg (0 );
161
- QualType ArgTyA = A.get ()->getType ();
162
- auto *VTyA = ArgTyA->getAs <VectorType>();
163
- if (VTyA == nullptr ) {
164
- SemaRef.Diag (A.get ()->getBeginLoc (),
165
- diag::err_typecheck_convert_incompatible)
166
- << ArgTyA
167
- << SemaRef.Context .getVectorType (ArgTyA, 2 , VectorKind::Generic) << 1
168
- << 0 << 0 ;
169
- return true ;
170
- }
171
-
172
- ExprResult B = TheCall->getArg (1 );
173
- QualType ArgTyB = B.get ()->getType ();
174
- auto *VTyB = ArgTyB->getAs <VectorType>();
175
- if (VTyB == nullptr ) {
176
- SemaRef.Diag (B.get ()->getBeginLoc (),
177
- diag::err_typecheck_convert_incompatible)
178
- << ArgTyB
179
- << SemaRef.Context .getVectorType (ArgTyB, 2 , VectorKind::Generic) << 1
180
- << 0 << 0 ;
186
+ // Use the helper function to check both arguments
187
+ if (CheckVectorArgs (TheCall, 2 ))
181
188
return true ;
182
- }
183
189
184
- QualType RetTy = VTyA ->getElementType ();
190
+ QualType RetTy = TheCall-> getArg ( 0 )-> getType ()-> getAs <VectorType>() ->getElementType ();
185
191
TheCall->setType (RetTy);
186
192
break ;
187
193
}
188
194
case SPIRV::BI__builtin_spirv_length: {
189
195
if (SemaRef.checkArgCount (TheCall, 1 ))
190
196
return true ;
191
- ExprResult A = TheCall->getArg (0 );
192
- QualType ArgTyA = A.get ()->getType ();
193
- auto *VTy = ArgTyA->getAs <VectorType>();
194
- if (VTy == nullptr ) {
195
- SemaRef.Diag (A.get ()->getBeginLoc (),
196
- diag::err_typecheck_convert_incompatible)
197
- << ArgTyA
198
- << SemaRef.Context .getVectorType (ArgTyA, 2 , VectorKind::Generic) << 1
199
- << 0 << 0 ;
197
+
198
+ // Use the helper function to check the argument
199
+ if (CheckVectorArgs (TheCall, 1 ))
200
200
return true ;
201
- }
202
- QualType RetTy = VTy ->getElementType ();
201
+
202
+ QualType RetTy = TheCall-> getArg ( 0 )-> getType ()-> getAs <VectorType>() ->getElementType ();
203
203
TheCall->setType (RetTy);
204
204
break ;
205
205
}
206
206
case SPIRV::BI__builtin_spirv_refract: {
207
207
if (SemaRef.checkArgCount (TheCall, 3 ))
208
208
return true ;
209
209
210
- ExprResult A = TheCall->getArg (0 );
211
- QualType ArgTyA = A.get ()->getType ();
212
- auto *VTyA = ArgTyA->getAs <VectorType>();
213
- if (VTyA == nullptr ) {
214
- SemaRef.Diag (A.get ()->getBeginLoc (),
215
- diag::err_typecheck_convert_incompatible)
216
- << ArgTyA
217
- << SemaRef.Context .getVectorType (ArgTyA, 2 , VectorKind::Generic) << 1
218
- << 0 << 0 ;
210
+ // Use the helper function to check the first two arguments
211
+ if (CheckVectorArgs (TheCall, 2 ))
219
212
return true ;
220
- }
221
-
222
- ExprResult B = TheCall->getArg (1 );
223
- QualType ArgTyB = B.get ()->getType ();
224
- auto *VTyB = ArgTyB->getAs <VectorType>();
225
- if (VTyB == nullptr ) {
226
- SemaRef.Diag (B.get ()->getBeginLoc (),
227
- diag::err_typecheck_convert_incompatible)
228
- << ArgTyB
229
- << SemaRef.Context .getVectorType (ArgTyB, 2 , VectorKind::Generic) << 1
230
- << 0 << 0 ;
231
- return true ;
232
- }
233
213
234
214
ExprResult C = TheCall->getArg (2 );
235
215
QualType ArgTyC = C.get ()->getType ();
236
- if (!ArgTyC->hasFloatingRepresentation ()) {
216
+ if (!ArgTyC->isFloatingType ()) {
237
217
SemaRef.Diag (C.get ()->getBeginLoc (), diag::err_builtin_invalid_arg_type)
238
- << 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
218
+ << 3 << /* scalar*/ 5 << /* no int */ 0 << /* fp */ 1
239
219
<< ArgTyC;
240
220
return true ;
241
221
}
242
222
243
- QualType RetTy = ArgTyA ;
223
+ QualType RetTy = TheCall-> getArg ( 0 )-> getType () ;
244
224
TheCall->setType (RetTy);
245
225
break ;
246
226
}
247
227
case SPIRV::BI__builtin_spirv_reflect: {
248
228
if (SemaRef.checkArgCount (TheCall, 2 ))
249
229
return true ;
250
230
251
- ExprResult A = TheCall->getArg (0 );
252
- QualType ArgTyA = A.get ()->getType ();
253
- auto *VTyA = ArgTyA->getAs <VectorType>();
254
- if (VTyA == nullptr ) {
255
- SemaRef.Diag (A.get ()->getBeginLoc (),
256
- diag::err_typecheck_convert_incompatible)
257
- << ArgTyA
258
- << SemaRef.Context .getVectorType (ArgTyA, 2 , VectorKind::Generic) << 1
259
- << 0 << 0 ;
260
- return true ;
261
- }
262
-
263
- ExprResult B = TheCall->getArg (1 );
264
- QualType ArgTyB = B.get ()->getType ();
265
- auto *VTyB = ArgTyB->getAs <VectorType>();
266
- if (VTyB == nullptr ) {
267
- SemaRef.Diag (B.get ()->getBeginLoc (),
268
- diag::err_typecheck_convert_incompatible)
269
- << ArgTyB
270
- << SemaRef.Context .getVectorType (ArgTyB, 2 , VectorKind::Generic) << 1
271
- << 0 << 0 ;
231
+ // Use the helper function to check both arguments
232
+ if (CheckVectorArgs (TheCall, 2 ))
272
233
return true ;
273
- }
274
234
275
- QualType RetTy = ArgTyA ;
235
+ QualType RetTy = TheCall-> getArg ( 0 )-> getType () ;
276
236
TheCall->setType (RetTy);
277
237
break ;
278
238
}
0 commit comments