7 changes: 3 additions & 4 deletions clang/lib/CodeGen/CGBuiltin.cpp
@@ -4289,7 +4289,7 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm::Value *Result;
if (BuiltinID == Builtin::BI__builtin_masked_load) {
Function *F =
CGM.getIntrinsic(Intrinsic::masked_load, {RetTy, UnqualPtrTy});
CGM.getIntrinsic(Intrinsic::masked_load, {RetTy, Ptr->getType()});
Result =
Builder.CreateCall(F, {Ptr, AlignVal, Mask, PassThru}, "masked_load");
} else {
@@ -4334,7 +4334,6 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,

QualType ValTy = E->getArg(1)->getType();
llvm::Type *ValLLTy = CGM.getTypes().ConvertType(ValTy);
llvm::Type *PtrTy = Ptr->getType();

CharUnits Align = CGM.getNaturalTypeAlignment(
E->getArg(1)->getType()->getAs<VectorType>()->getElementType(),
@@ -4343,8 +4342,8 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
llvm::ConstantInt::get(Int32Ty, Align.getQuantity());

if (BuiltinID == Builtin::BI__builtin_masked_store) {
llvm::Function *F =
CGM.getIntrinsic(llvm::Intrinsic::masked_store, {ValLLTy, PtrTy});
llvm::Function *F = CGM.getIntrinsic(llvm::Intrinsic::masked_store,
{ValLLTy, Ptr->getType()});
Builder.CreateCall(F, {Val, Ptr, AlignVal, Mask});
} else {
llvm::Function *F =
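
Note on the CGBuiltin.cpp change: the intrinsic overload is now instantiated with Ptr->getType() instead of the unqualified default-address-space pointer type, so llvm.masked.load / llvm.masked.store are mangled with the source pointer's actual address space. A minimal sketch of the effect, using the v8b/v8i vector typedefs from the test file below (this only compiles with a clang that provides these builtins; the address space number 42 is arbitrary):

typedef int v8i __attribute__((ext_vector_type(8)));
typedef _Bool v8b __attribute__((ext_vector_type(8)));

v8i load_as42(v8b mask, int __attribute__((address_space(42))) *ptr) {
  /* With {RetTy, Ptr->getType()} as the overload types, this call is expected
     to lower to @llvm.masked.load.v8i32.p42 on a ptr addrspace(42) operand
     rather than the default-address-space overload. */
  return __builtin_masked_load(mask, ptr);
}
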
90 changes: 74 additions & 16 deletions clang/lib/Sema/SemaChecking.cpp
@@ -2268,7 +2268,8 @@ static bool BuiltinCountZeroBitsGeneric(Sema &S, CallExpr *TheCall) {
}

static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
unsigned Pos) {
unsigned Pos, bool AllowConst,
bool AllowAS) {
QualType MaskTy = MaskArg->getType();
if (!MaskTy->isExtVectorBoolType())
return S.Diag(MaskArg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
@@ -2279,25 +2280,62 @@ static bool CheckMaskedBuiltinArgs(Sema &S, Expr *MaskArg, Expr *PtrArg,
if (!PtrTy->isPointerType() || PtrTy->getPointeeType()->isVectorType())
return S.Diag(PtrArg->getExprLoc(), diag::err_vec_masked_load_store_ptr)
<< Pos << "scalar pointer";

QualType PointeeTy = PtrTy->getPointeeType();
if (PointeeTy.isVolatileQualified() || PointeeTy->isAtomicType() ||
(!AllowConst && PointeeTy.isConstQualified()) ||
(!AllowAS && PointeeTy.hasAddressSpace())) {
QualType Target =
S.Context.getPointerType(PointeeTy.getAtomicUnqualifiedType());
return S.Diag(PtrArg->getExprLoc(),
diag::err_typecheck_convert_incompatible)
<< PtrTy << Target << /*different qualifiers=*/5
<< /*qualifier difference=*/0 << /*parameter mismatch=*/3 << 2
<< PtrTy << Target;
}
return false;
}

static bool ConvertMaskedBuiltinArgs(Sema &S, CallExpr *TheCall) {
bool TypeDependent = false;
for (unsigned Arg = 0, E = TheCall->getNumArgs(); Arg != E; ++Arg) {
ExprResult Converted =
S.DefaultFunctionArrayLvalueConversion(TheCall->getArg(Arg));
if (Converted.isInvalid())
return true;
TheCall->setArg(Arg, Converted.get());
TypeDependent |= Converted.get()->isTypeDependent();
}

if (TypeDependent)
TheCall->setType(S.Context.DependentTy);
return false;
}

static ExprResult BuiltinMaskedLoad(Sema &S, CallExpr *TheCall) {
if (S.checkArgCountRange(TheCall, 2, 3))
return ExprError();

if (ConvertMaskedBuiltinArgs(S, TheCall))
return ExprError();

Expr *MaskArg = TheCall->getArg(0);
Expr *PtrArg = TheCall->getArg(1);
if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 2))
if (TheCall->isTypeDependent())
return TheCall;

if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 2, /*AllowConst=*/true,
TheCall->getBuiltinCallee() ==
Builtin::BI__builtin_masked_load))
return ExprError();

QualType MaskTy = MaskArg->getType();
QualType PtrTy = PtrArg->getType();
QualType PointeeTy = PtrTy->getPointeeType();
const VectorType *MaskVecTy = MaskTy->getAs<VectorType>();

QualType RetTy =
S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
QualType RetTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
MaskVecTy->getNumElements());
if (TheCall->getNumArgs() == 3) {
Expr *PassThruArg = TheCall->getArg(2);
QualType PassThruTy = PassThruArg->getType();
@@ -2314,11 +2352,18 @@ static ExprResult BuiltinMaskedStore(Sema &S, CallExpr *TheCall) {
if (S.checkArgCount(TheCall, 3))
return ExprError();

if (ConvertMaskedBuiltinArgs(S, TheCall))
return ExprError();

Expr *MaskArg = TheCall->getArg(0);
Expr *ValArg = TheCall->getArg(1);
Expr *PtrArg = TheCall->getArg(2);
if (TheCall->isTypeDependent())
return TheCall;

if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3))
if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3, /*AllowConst=*/false,
TheCall->getBuiltinCallee() ==
Builtin::BI__builtin_masked_store))
return ExprError();

QualType MaskTy = MaskArg->getType();
@@ -2331,10 +2376,10 @@

QualType PointeeTy = PtrTy->getPointeeType();
const VectorType *MaskVecTy = MaskTy->getAs<VectorType>();
QualType RetTy =
S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());

if (!S.Context.hasSameType(ValTy, RetTy))
QualType MemoryTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
MaskVecTy->getNumElements());
if (!S.Context.hasSameType(ValTy.getUnqualifiedType(),
MemoryTy.getUnqualifiedType()))
return ExprError(S.Diag(TheCall->getBeginLoc(),
diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee() << /*isMorethantwoArgs*/ 2
@@ -2349,10 +2394,17 @@ static ExprResult BuiltinMaskedGather(Sema &S, CallExpr *TheCall) {
if (S.checkArgCountRange(TheCall, 3, 4))
return ExprError();

if (ConvertMaskedBuiltinArgs(S, TheCall))
return ExprError();

Expr *MaskArg = TheCall->getArg(0);
Expr *IdxArg = TheCall->getArg(1);
Expr *PtrArg = TheCall->getArg(2);
if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3))
if (TheCall->isTypeDependent())
return TheCall;

if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3, /*AllowConst=*/true,
/*AllowAS=*/true))
return ExprError();

QualType IdxTy = IdxArg->getType();
@@ -2373,8 +2425,8 @@ static ExprResult BuiltinMaskedGather(Sema &S, CallExpr *TheCall) {
TheCall->getBuiltinCallee())
<< MaskTy << IdxTy);

QualType RetTy =
S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
QualType RetTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
MaskVecTy->getNumElements());
if (TheCall->getNumArgs() == 4) {
Expr *PassThruArg = TheCall->getArg(3);
QualType PassThruTy = PassThruArg->getType();
@@ -2392,12 +2444,18 @@ static ExprResult BuiltinMaskedScatter(Sema &S, CallExpr *TheCall) {
if (S.checkArgCount(TheCall, 4))
return ExprError();

if (ConvertMaskedBuiltinArgs(S, TheCall))
return ExprError();

Expr *MaskArg = TheCall->getArg(0);
Expr *IdxArg = TheCall->getArg(1);
Expr *ValArg = TheCall->getArg(2);
Expr *PtrArg = TheCall->getArg(3);
if (TheCall->isTypeDependent())
return TheCall;

if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 3))
if (CheckMaskedBuiltinArgs(S, MaskArg, PtrArg, 4, /*AllowConst=*/false,
/*AllowAS=*/true))
return ExprError();

QualType IdxTy = IdxArg->getType();
@@ -2427,9 +2485,9 @@ static ExprResult BuiltinMaskedScatter(Sema &S, CallExpr *TheCall) {
TheCall->getBuiltinCallee())
<< MaskTy << ValTy);

QualType ArgTy =
S.Context.getExtVectorType(PointeeTy, MaskVecTy->getNumElements());
if (!S.Context.hasSameType(ValTy, ArgTy))
QualType ArgTy = S.Context.getExtVectorType(PointeeTy.getUnqualifiedType(),
MaskVecTy->getNumElements());
if (!S.Context.hasSameType(ValTy.getUnqualifiedType(), ArgTy))
return ExprError(S.Diag(TheCall->getBeginLoc(),
diag::err_vec_builtin_incompatible_vector)
<< TheCall->getDirectCallee() << /*isMoreThanTwoArgs*/ 2
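
Note on the SemaChecking.cpp changes: the arguments now go through DefaultFunctionArrayLvalueConversion (with an early return for type-dependent calls), and CheckMaskedBuiltinArgs gains AllowConst/AllowAS flags: volatile and _Atomic pointees are always rejected, const pointees are only accepted for the load/gather forms, and address-space-qualified pointees are accepted when the callee is __builtin_masked_load or __builtin_masked_store (and always for gather/scatter), but not for the other builtins sharing these paths. The result/value vector type is now built from the unqualified element type. A small sketch of what these checks are meant to accept and reject, assuming a clang with these builtins (the rejected cases are left commented out so the snippet compiles):

typedef int v8i __attribute__((ext_vector_type(8)));
typedef _Bool v8b __attribute__((ext_vector_type(8)));

v8i load_from_const(v8b m, const int *p) {
  /* const pointee is accepted for loads (AllowConst=true); the result type is
     built from the unqualified element type, so this still returns v8i. */
  return __builtin_masked_load(m, p);
}

void store_cases(v8b m, v8i v, int *p, const int *cp, volatile int *vp) {
  __builtin_masked_store(m, v, p);     /* ok: plain int pointee */
  /* __builtin_masked_store(m, v, cp);   would be diagnosed: const pointee, AllowConst=false */
  /* __builtin_masked_store(m, v, vp);   would be diagnosed: volatile pointees are always rejected */
}
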
106 changes: 106 additions & 0 deletions clang/test/CodeGen/builtin-masked.c
@@ -187,3 +187,109 @@ v8i test_gather(v8b mask, v8i idx, int *ptr) {
void test_scatter(v8b mask, v8i val, v8i idx, int *ptr) {
__builtin_masked_scatter(mask, val, idx, ptr);
}

// CHECK-LABEL: define dso_local <8 x i32> @test_load_as(
// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr addrspace(42) noundef [[PTR:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr addrspace(42), align 8
// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
// CHECK-NEXT: [[TMP0:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
// CHECK-NEXT: store i8 [[TMP0]], ptr [[MASK_ADDR]], align 1
// CHECK-NEXT: store ptr addrspace(42) [[PTR]], ptr [[PTR_ADDR]], align 8
// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
// CHECK-NEXT: [[TMP2:%.*]] = load ptr addrspace(42), ptr [[PTR_ADDR]], align 8
// CHECK-NEXT: [[MASKED_LOAD:%.*]] = call <8 x i32> @llvm.masked.load.v8i32.p42(ptr addrspace(42) [[TMP2]], i32 4, <8 x i1> [[TMP1]], <8 x i32> poison)
// CHECK-NEXT: ret <8 x i32> [[MASKED_LOAD]]
//
v8i test_load_as(v8b mask, int __attribute__((address_space(42))) * ptr) {
return __builtin_masked_load(mask, ptr);
}

// CHECK-LABEL: define dso_local void @test_store_as(
// CHECK-SAME: i8 noundef [[M_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr addrspace(42) noundef [[P:%.*]]) #[[ATTR3]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[M:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[M_ADDR:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[V_ADDR:%.*]] = alloca <8 x i32>, align 32
// CHECK-NEXT: [[P_ADDR:%.*]] = alloca ptr addrspace(42), align 8
// CHECK-NEXT: store i8 [[M_COERCE]], ptr [[M]], align 1
// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[M]], align 1
// CHECK-NEXT: [[M1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
// CHECK-NEXT: [[V:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[M1]] to i8
// CHECK-NEXT: store i8 [[TMP1]], ptr [[M_ADDR]], align 1
// CHECK-NEXT: store <8 x i32> [[V]], ptr [[V_ADDR]], align 32
// CHECK-NEXT: store ptr addrspace(42) [[P]], ptr [[P_ADDR]], align 8
// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[M_ADDR]], align 1
// CHECK-NEXT: [[TMP2:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
// CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[V_ADDR]], align 32
// CHECK-NEXT: [[TMP4:%.*]] = load ptr addrspace(42), ptr [[P_ADDR]], align 8
// CHECK-NEXT: call void @llvm.masked.store.v8i32.p42(<8 x i32> [[TMP3]], ptr addrspace(42) [[TMP4]], i32 4, <8 x i1> [[TMP2]])
// CHECK-NEXT: ret void
//
void test_store_as(v8b m, v8i v, int __attribute__((address_space(42))) *p) {
__builtin_masked_store(m, v, p);
}

// CHECK-LABEL: define dso_local <8 x i32> @test_gather_as(
// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr addrspace(42) noundef [[PTR:%.*]]) #[[ATTR0]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr addrspace(42), align 8
// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
// CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
// CHECK-NEXT: store i8 [[TMP1]], ptr [[MASK_ADDR]], align 1
// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
// CHECK-NEXT: store ptr addrspace(42) [[PTR]], ptr [[PTR_ADDR]], align 8
// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
// CHECK-NEXT: [[TMP2:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
// CHECK-NEXT: [[TMP3:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
// CHECK-NEXT: [[TMP4:%.*]] = load ptr addrspace(42), ptr [[PTR_ADDR]], align 8
// CHECK-NEXT: [[TMP5:%.*]] = getelementptr i32, ptr addrspace(42) [[TMP4]], <8 x i32> [[TMP3]]
// CHECK-NEXT: [[MASKED_GATHER:%.*]] = call <8 x i32> @llvm.masked.gather.v8i32.v8p42(<8 x ptr addrspace(42)> [[TMP5]], i32 4, <8 x i1> [[TMP2]], <8 x i32> poison)
// CHECK-NEXT: ret <8 x i32> [[MASKED_GATHER]]
//
v8i test_gather_as(v8b mask, v8i idx, int __attribute__((address_space(42))) *ptr) {
return __builtin_masked_gather(mask, idx, ptr);
}

// CHECK-LABEL: define dso_local void @test_scatter_as(
// CHECK-SAME: i8 noundef [[MASK_COERCE:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP0:%.*]], ptr noundef byval(<8 x i32>) align 32 [[TMP1:%.*]], ptr addrspace(42) noundef [[PTR:%.*]]) #[[ATTR3]] {
// CHECK-NEXT: [[ENTRY:.*:]]
// CHECK-NEXT: [[MASK:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[MASK_ADDR:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[VAL_ADDR:%.*]] = alloca <8 x i32>, align 32
// CHECK-NEXT: [[IDX_ADDR:%.*]] = alloca <8 x i32>, align 32
// CHECK-NEXT: [[PTR_ADDR:%.*]] = alloca ptr addrspace(42), align 8
// CHECK-NEXT: store i8 [[MASK_COERCE]], ptr [[MASK]], align 1
// CHECK-NEXT: [[LOAD_BITS:%.*]] = load i8, ptr [[MASK]], align 1
// CHECK-NEXT: [[MASK1:%.*]] = bitcast i8 [[LOAD_BITS]] to <8 x i1>
// CHECK-NEXT: [[VAL:%.*]] = load <8 x i32>, ptr [[TMP0]], align 32
// CHECK-NEXT: [[IDX:%.*]] = load <8 x i32>, ptr [[TMP1]], align 32
// CHECK-NEXT: [[TMP2:%.*]] = bitcast <8 x i1> [[MASK1]] to i8
// CHECK-NEXT: store i8 [[TMP2]], ptr [[MASK_ADDR]], align 1
// CHECK-NEXT: store <8 x i32> [[VAL]], ptr [[VAL_ADDR]], align 32
// CHECK-NEXT: store <8 x i32> [[IDX]], ptr [[IDX_ADDR]], align 32
// CHECK-NEXT: store ptr addrspace(42) [[PTR]], ptr [[PTR_ADDR]], align 8
// CHECK-NEXT: [[LOAD_BITS2:%.*]] = load i8, ptr [[MASK_ADDR]], align 1
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8 [[LOAD_BITS2]] to <8 x i1>
// CHECK-NEXT: [[TMP4:%.*]] = load <8 x i32>, ptr [[VAL_ADDR]], align 32
// CHECK-NEXT: [[TMP5:%.*]] = load <8 x i32>, ptr [[IDX_ADDR]], align 32
// CHECK-NEXT: [[TMP6:%.*]] = load ptr addrspace(42), ptr [[PTR_ADDR]], align 8
// CHECK-NEXT: [[TMP7:%.*]] = getelementptr i32, ptr addrspace(42) [[TMP6]], <8 x i32> [[TMP4]]
// CHECK-NEXT: call void @llvm.masked.scatter.v8i32.v8p42(<8 x i32> [[TMP5]], <8 x ptr addrspace(42)> [[TMP7]], i32 4, <8 x i1> [[TMP3]])
// CHECK-NEXT: ret void
//
void test_scatter_as(v8b mask, v8i val, v8i idx, int __attribute__((address_space(42))) *ptr) {
__builtin_masked_scatter(mask, val, idx, ptr);
}