
Commit 8a9aa18

[Clang][FIX] Fix type qualifiers on vector builtins (#160185)
Summary: The masked vector builtins were not stripping qualifiers from the pointee type when using it to infer the vector types, leading to errors when mixing const and non-const operands.
1 parent: bad92c9
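
For illustration, a minimal sketch of the kind of code this commit fixes; the v8b/v8i typedefs are assumptions matching the existing tests in clang/test/CodeGen/builtin-masked.c. Before the fix, the const qualifier on the pointee leaked into the inferred vector type, so the call failed to type-check against plain v8i:

typedef _Bool v8b __attribute__((ext_vector_type(8)));
typedef int v8i __attribute__((ext_vector_type(8)));

v8i read_only(v8b mask, const int *src) {
  // The result type is inferred from the pointee; it must come out as
  // v8i (a vector of int), not a vector of 'const int'.
  return __builtin_masked_load(mask, src);
}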

6 files changed, +276 -22 lines


clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 3 additions & 4 deletions
@@ -4289,7 +4289,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
     llvm::Value *Result;
     if (BuiltinID == Builtin::BI__builtin_masked_load) {
       Function *F =
-          CGM.getIntrinsic(Intrinsic::masked_load, {RetTy, UnqualPtrTy});
+          CGM.getIntrinsic(Intrinsic::masked_load, {RetTy, Ptr->getType()});
       Result =
           Builder.CreateCall(F, {Ptr, AlignVal, Mask, PassThru}, "masked_load");
     } else {
@@ -4334,7 +4334,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
 
     QualType ValTy = E->getArg(1)->getType();
     llvm::Type *ValLLTy = CGM.getTypes().ConvertType(ValTy);
-    llvm::Type *PtrTy = Ptr->getType();
 
     CharUnits Align = CGM.getNaturalTypeAlignment(
         E->getArg(1)->getType()->getAs<VectorType>()->getElementType(),
@@ -4343,8 +4342,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
         llvm::ConstantInt::get(Int32Ty, Align.getQuantity());
 
     if (BuiltinID == Builtin::BI__builtin_masked_store) {
-      llvm::Function *F =
-          CGM.getIntrinsic(llvm::Intrinsic::masked_store, {ValLLTy, PtrTy});
+      llvm::Function *F = CGM.getIntrinsic(llvm::Intrinsic::masked_store,
+                                           {ValLLTy, Ptr->getType()});
       Builder.CreateCall(F, {Val, Ptr, AlignVal, Mask});
     } else {
       llvm::Function *F =
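
A note on the codegen change, as a hedged reading of the diff above: UnqualPtrTy is a pointer in the default address space, so overloading the masked_load/masked_store intrinsics on Ptr->getType() instead is what lets a pointer's actual address space flow into the intrinsic's mangled name (the new tests below expect @llvm.masked.load.v8i32.p42, for example). A sketch of the source shape that exercises this; the function name is illustrative:

typedef _Bool v8b __attribute__((ext_vector_type(8)));
typedef int v8i __attribute__((ext_vector_type(8)));

v8i load_as42(v8b mask, int __attribute__((address_space(42))) *ptr) {
  // Lowers to @llvm.masked.load.v8i32.p42 rather than the addrspace(0)
  // overload the old UnqualPtrTy-based call would have requested.
  return __builtin_masked_load(mask, ptr);
}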

clang/lib/Sema/SemaChecking.cpp

Lines changed: 74 additions & 16 deletions
@@ -2268,7 +2268,8 @@ static bool BuiltinCountZeroBitsGeneric(Sema &S, CallExpr *TheCall) {
 }
 
 static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
-                                   unsigned Pos) {
+                                   unsigned Pos, bool AllowConst,
+                                   bool AllowAS) {
   QualType MaskTy = MaskArg->getType();
   if (!MaskTy->isExtVectorBoolType())
     return S.Diag(MaskArg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
@@ -2279,25 +2280,62 @@ static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
   if (!PtrTy->isPointerType() || PtrTy->getPointeeType()->isVectorType())
     return S.Diag(PtrArg->getExprLoc(), diag::err_vec_masked_load_store_ptr)
            << Pos << "scalar pointer";
+
+  QualType PointeeTy = PtrTy->getPointeeType();
+  if (PointeeTy.isVolatileQualified() || PointeeTy->isAtomicType() ||
+      (!AllowConst && PointeeTy.isConstQualified()) ||
+      (!AllowAS && PointeeTy.hasAddressSpace())) {
+    QualType Target =
+        S.Context.getPointerType(PointeeTy.getAtomicUnqualifiedType());
+    return S.Diag(PtrArg->getExprLoc(),
+                  diag::err_typecheck_convert_incompatible)
+           << PtrTy << Target << /*different qualifiers=*/5
+           << /*qualifier difference=*/0 << /*parameter mismatch=*/3 << 2
+           << PtrTy << Target;
+  }
+  return false;
+}
+
+static bool ConvertMaskedBuiltinArgs(Sema &S, CallExpr *TheCall) {
+  bool TypeDependent = false;
+  for (unsigned Arg = 0, E = TheCall->getNumArgs(); Arg != E; ++Arg) {
+    ExprResult Converted =
+        S.DefaultFunctionArrayLvalueConversion(TheCall->getArg(Arg));
+    if (Converted.isInvalid())
+      return true;
+    TheCall->setArg(Arg, Converted.get());
+    TypeDependent |= Converted.get()->isTypeDependent();
+  }
+
+  if (TypeDependent)
+    TheCall->setType(S.Context.DependentTy);
   return false;
 }
 
 static ExprResult BuiltinMaskedLoad(Sema &S, CallExpr *TheCall) {
   if (S.checkArgCountRange(TheCall, 2, 3))
     return ExprError();
 
+  if (ConvertMaskedBuiltinArgs(S, TheCall))
+    return ExprError();
+
   Expr *MaskArg = TheCall->getArg(0);
   Expr *PtrArg = TheCall->getArg(1);
-  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 2))
+  if (TheCall->isTypeDependent())
+    return TheCall;
+
+  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 2, /*AllowConst=*/true,
+                             TheCall->getBuiltinCallee() ==
+                                 Builtin::BI__builtin_masked_load))
     return ExprError();
 
   QualType MaskTy = MaskArg->getType();
   QualType PtrTy = PtrArg->getType();
   QualType PointeeTy = PtrTy->getPointeeType();
   const VectorType *MaskVecTy = MaskTy->getAs<VectorType>();
 
-  QualType RetTy =
-      S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
+  QualType RetTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
+                                              MaskVecTy->getNumElements());
   if (TheCall->getNumArgs() == 3) {
     Expr *PassThruArg = TheCall->getArg(2);
     QualType PassThruTy = PassThruArg->getType();
@@ -2314,11 +2352,18 @@ static ExprResult BuiltinMaskedStore(Sema &S, CallExpr *TheCall) {
   if (S.checkArgCount(TheCall, 3))
     return ExprError();
 
+  if (ConvertMaskedBuiltinArgs(S, TheCall))
+    return ExprError();
+
   Expr *MaskArg = TheCall->getArg(0);
   Expr *ValArg = TheCall->getArg(1);
   Expr *PtrArg = TheCall->getArg(2);
+  if (TheCall->isTypeDependent())
+    return TheCall;
 
-  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3))
+  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3, /*AllowConst=*/false,
+                             TheCall->getBuiltinCallee() ==
+                                 Builtin::BI__builtin_masked_store))
     return ExprError();
 
   QualType MaskTy = MaskArg->getType();
@@ -2331,10 +2376,10 @@ static ExprResult BuiltinMaskedStore(Sema &S, CallExpr *TheCall) {
 
   QualType PointeeTy = PtrTy->getPointeeType();
   const VectorType *MaskVecTy = MaskTy->getAs<VectorType>();
-  QualType RetTy =
-      S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
-
-  if (!S.Context.hasSameType(ValTy, RetTy))
+  QualType MemoryTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
+                                                 MaskVecTy->getNumElements());
+  if (!S.Context.hasSameType(ValTy.getUnqualifiedType(),
+                             MemoryTy.getUnqualifiedType()))
     return ExprError(S.Diag(TheCall->getBeginLoc(),
                             diag::err_vec_builtin_incompatible_vector)
                      << TheCall->getDirectCallee() << /*isMorethantwoArgs*/ 2
@@ -2349,10 +2394,17 @@ static ExprResult BuiltinMaskedGather(Sema &S, CallExpr *TheCall) {
   if (S.checkArgCountRange(TheCall, 3, 4))
     return ExprError();
 
+  if (ConvertMaskedBuiltinArgs(S, TheCall))
+    return ExprError();
+
   Expr *MaskArg = TheCall->getArg(0);
   Expr *IdxArg = TheCall->getArg(1);
   Expr *PtrArg = TheCall->getArg(2);
-  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3))
+  if (TheCall->isTypeDependent())
+    return TheCall;
+
+  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3, /*AllowConst=*/true,
+                             /*AllowAS=*/true))
     return ExprError();
 
   QualType IdxTy = IdxArg->getType();
@@ -2373,8 +2425,8 @@ static ExprResult BuiltinMaskedGather(Sema &S, CallExpr *TheCall) {
                           TheCall->getBuiltinCallee())
                      << MaskTy << IdxTy);
 
-  QualType RetTy =
-      S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
+  QualType RetTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
+                                              MaskVecTy->getNumElements());
   if (TheCall->getNumArgs() == 4) {
     Expr *PassThruArg = TheCall->getArg(3);
     QualType PassThruTy = PassThruArg->getType();
@@ -2392,12 +2444,18 @@ static ExprResult BuiltinMaskedScatter(Sema &S, CallExpr *TheCall) {
   if (S.checkArgCount(TheCall, 4))
     return ExprError();
 
+  if (ConvertMaskedBuiltinArgs(S, TheCall))
+    return ExprError();
+
   Expr *MaskArg = TheCall->getArg(0);
   Expr *IdxArg = TheCall->getArg(1);
   Expr *ValArg = TheCall->getArg(2);
   Expr *PtrArg = TheCall->getArg(3);
+  if (TheCall->isTypeDependent())
+    return TheCall;
 
-  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3))
+  if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 4, /*AllowConst=*/false,
+                             /*AllowAS=*/true))
     return ExprError();
 
   QualType IdxTy = IdxArg->getType();
@@ -2427,9 +2485,9 @@ static ExprResult BuiltinMaskedScatter(Sema &S, CallExpr *TheCall) {
                           TheCall->getBuiltinCallee())
                      << MaskTy << ValTy);
 
-  QualType ArgTy =
-      S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
-  if (!S.Context.hasSameType(ValTy, ArgTy))
+  QualType ArgTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
+                                              MaskVecTy->getNumElements());
+  if (!S.Context.hasSameType(ValTy.getUnqualifiedType(), ArgTy))
     return ExprError(S.Diag(TheCall->getBeginLoc(),
                             diag::err_vec_builtin_incompatible_vector)
                      << TheCall->getDirectCallee() << /*isMoreThanTwoArgs*/ 2
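
The new CheckMaskedBuiltinArgs parameters encode a small policy: volatile and _Atomic pointees are always rejected, const pointees are allowed only for the reading builtins (AllowConst), and non-default address spaces only where supported (AllowAS). ConvertMaskedBuiltinArgs additionally applies the default lvalue conversions up front and defers all checking when an argument is type-dependent. A hedged sketch of what now type-checks versus what is diagnosed, again assuming the test file's typedefs:

typedef _Bool v8b __attribute__((ext_vector_type(8)));
typedef int v8i __attribute__((ext_vector_type(8)));

v8i ok_load(v8b m, const int *p) {
  return __builtin_masked_load(m, p);   // OK: loads may read const memory
}

void bad_store(v8b m, v8i v, const int *p) {
  __builtin_masked_store(m, v, p);      // error: AllowConst is false for stores
}

v8i bad_volatile(v8b m, volatile int *p) {
  return __builtin_masked_load(m, p);   // error: volatile pointees are rejected
}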

clang/test/CodeGen/builtin-masked.c

Lines changed: 106 additions & 0 deletions
@@ -187,3 +187,109 @@ v8i test_gather(v8b mask, v8i idx, int *ptr) {
 void test_scatter(v8b mask, v8i val, v8i idx, int *ptr) {
   __builtin_masked_scatter(mask, val, idx, ptr);
 }
+
+// CHECK-LABEL: define dso_local <8 x i32> @test_load_as(
+// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr addrspace(42) noundef [[PTR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr addrspace(42), align 8
+// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
+// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
+// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
+// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
+// CHECK-NEXT: store i8 [[TMP0]], ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: store ptr addrspace(42) [[PTR]], ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
+// CHECK-NEXT: [[TMP2:%.*]] = load ptr addrspace(42), ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[MASKED_LOAD:%.*]] = call <8 x i32> @llvm.masked.load.v8i32.p42(ptr addrspace(42) [[TMP2]], i32 4, <8 x i1> [[TMP1]], <8 x i32> poison)
+// CHECK-NEXT: ret <8 x i32> [[MASKED_LOAD]]
+//
+v8i test_load_as(v8b mask, int __attribute__((address_space(42))) * ptr) {
+  return __builtin_masked_load(mask, ptr);
+}
+
+// CHECK-LABEL: define dso_local void @test_store_as(
+// CHECK-SAME: i8 noundef [[M_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr addrspace(42) noundef [[P:%.*]]) #[[ATTR3]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[M:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[M_ADDR:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[V_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[P_ADDR:%.*]] = alloca ptr addrspace(42), align 8
+// CHECK-NEXT: store i8 [[M_COERCE]], ptr [[M]], align 1
+// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[M]], align 1
+// CHECK-NEXT: [[M1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
+// CHECK-NEXT: [[V:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[M1]] to i8
+// CHECK-NEXT: store i8 [[TMP1]], ptr [[M_ADDR]], align 1
+// CHECK-NEXT: store <8 x i32> [[V]], ptr [[V_ADDR]], align 32
+// CHECK-NEXT: store ptr addrspace(42) [[P]], ptr [[P_ADDR]], align 8
+// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[M_ADDR]], align 1
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
+// CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[V_ADDR]], align 32
+// CHECK-NEXT: [[TMP4:%.*]] = load ptr addrspace(42), ptr [[P_ADDR]], align 8
+// CHECK-NEXT: call void @llvm.masked.store.v8i32.p42(<8 x i32> [[TMP3]], ptr addrspace(42) [[TMP4]], i32 4, <8 x i1> [[TMP2]])
+// CHECK-NEXT: ret void
+//
+void test_store_as(v8b m, v8i v, int __attribute__((address_space(42))) *p) {
+  __builtin_masked_store(m, v, p);
+}
+
+// CHECK-LABEL: define dso_local <8 x i32> @test_gather_as(
+// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr addrspace(42) noundef [[PTR:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr addrspace(42), align 8
+// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
+// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
+// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
+// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
+// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
+// CHECK-NEXT: store i8 [[TMP1]], ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: store ptr addrspace(42) [[PTR]], ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
+// CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: [[TMP4:%.*]] = load ptr addrspace(42), ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[TMP5:%.*]] = getelementptr i32, ptr addrspace(42) [[TMP4]], <8 x i32> [[TMP3]]
+// CHECK-NEXT: [[MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p42(<8 x ptr addrspace(42)> [[TMP5]], i32 4, <8 x i1> [[TMP2]], <8 x i32> poison)
+// CHECK-NEXT: ret <8 x i32> [[MASKED_GATHER]]
+//
+v8i test_gather_as(v8b mask, v8i idx, int __attribute__((address_space(42))) *ptr) {
+  return __builtin_masked_gather(mask, idx, ptr);
+}
+
+// CHECK-LABEL: define dso_local void @test_scatter_as(
+// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP1:%.*]], ptr addrspace(42) noundef [[PTR:%.*]]) #[[ATTR3]] {
+// CHECK-NEXT: [[ENTRY:.*:]]
+// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
+// CHECK-NEXT: [[VAL_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
+// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr addrspace(42), align 8
+// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
+// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
+// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
+// CHECK-NEXT: [[VAL:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
+// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP1]], align 32
+// CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
+// CHECK-NEXT: store i8 [[TMP2]], ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: store <8 x i32> [[VAL]], ptr [[VAL_ADDR]], align 32
+// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: store ptr addrspace(42) [[PTR]], ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
+// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
+// CHECK-NEXT: [[TMP4:%.*]] = load <8 x i32>, ptr [[VAL_ADDR]], align 32
+// CHECK-NEXT: [[TMP5:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
+// CHECK-NEXT: [[TMP6:%.*]] = load ptr addrspace(42), ptr [[PTR_ADDR]], align 8
+// CHECK-NEXT: [[TMP7:%.*]] = getelementptr i32, ptr addrspace(42) [[TMP6]], <8 x i32> [[TMP4]]
+// CHECK-NEXT: call void @llvm.masked.scatter.v8i32.p42(<8 x i32> [[TMP5]], <8 x ptr addrspace(42)> [[TMP7]], i32 4, <8 x i1> [[TMP3]])
+// CHECK-NEXT: ret void
+//
+void test_scatter_as(v8b mask, v8i val, v8i idx, int __attribute__((address_space(42))) *ptr) {
+  __builtin_masked_scatter(mask, val, idx, ptr);
+}
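
For completeness: the checkArgCountRange(TheCall, 2, 3) call in Sema and the PassThru operand in codegen correspond to the optional pass-through form of the load. A short usage sketch under the same assumed typedefs, relying on the documented llvm.masked.load semantics that masked-off lanes take the pass-through value:

typedef _Bool v8b __attribute__((ext_vector_type(8)));
typedef int v8i __attribute__((ext_vector_type(8)));

v8i load_or(v8b mask, const int *p, v8i fallback) {
  // Lanes where mask is false take the corresponding lane of 'fallback'.
  return __builtin_masked_load(mask, p, fallback);
}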
