Skip to content

Commit f9568e8

Browse files
authored
[HLSL] Make memory representation of boolean vectors in HLSL, vectors of i32. Add support for boolean swizzling. (llvm#123977)
Make the memory representation of boolean vectors in HLSL, vectors of i32. Allow boolean swizzling for boolean vectors in HLSL. Add tests for boolean vectors and boolean vector swizzling. Closes llvm#91639
1 parent dafb566 commit f9568e8

File tree

14 files changed

+352
-37
lines changed

14 files changed

+352
-37
lines changed

clang/include/clang/AST/Type.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,6 +2568,9 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
25682568
bool isVectorType() const; // GCC vector type.
25692569
bool isExtVectorType() const; // Extended vector type.
25702570
bool isExtVectorBoolType() const; // Extended vector type with bool element.
2571+
// Extended vector type with bool element that is packed. HLSL doesn't pack
2572+
// its bool vectors.
2573+
bool isPackedVectorBoolType(const ASTContext &ctx) const;
25712574
bool isSubscriptableVectorType() const;
25722575
bool isMatrixType() const; // Matrix type.
25732576
bool isConstantMatrixType() const; // Constant matrix type.

clang/lib/AST/ASTContext.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2005,8 +2005,9 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const {
20052005
case Type::Vector: {
20062006
const auto *VT = cast<VectorType>(T);
20072007
TypeInfo EltInfo = getTypeInfo(VT->getElementType());
2008-
Width = VT->isExtVectorBoolType() ? VT->getNumElements()
2009-
: EltInfo.Width * VT->getNumElements();
2008+
Width = VT->isPackedVectorBoolType(*this)
2009+
? VT->getNumElements()
2010+
: EltInfo.Width * VT->getNumElements();
20102011
// Enforce at least byte size and alignment.
20112012
Width = std::max<unsigned>(8, Width);
20122013
Align = std::max<unsigned>(8, Width);

clang/lib/AST/ByteCode/InterpBuiltinBitCast.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ static bool enumerateData(const Pointer &P, const Context &Ctx, Bits Offset,
9494
Bits ElemSize = Bits(Ctx.getASTContext().getTypeSize(ElemType));
9595
PrimType ElemT = *Ctx.classify(ElemType);
9696
// Special case, since the bools here are packed.
97-
bool PackedBools = FieldDesc->getType()->isExtVectorBoolType();
97+
bool PackedBools =
98+
FieldDesc->getType()->isPackedVectorBoolType(Ctx.getASTContext());
9899
unsigned NumElems = FieldDesc->getNumElems();
99100
bool Ok = true;
100101
for (unsigned I = P.getIndex(); I != NumElems; ++I) {
@@ -227,7 +228,7 @@ static bool CheckBitcastType(InterpState &S, CodePtr OpPC, QualType T,
227228
QualType EltTy = VT->getElementType();
228229
unsigned NElts = VT->getNumElements();
229230
unsigned EltSize =
230-
VT->isExtVectorBoolType() ? 1 : ASTCtx.getTypeSize(EltTy);
231+
VT->isPackedVectorBoolType(ASTCtx) ? 1 : ASTCtx.getTypeSize(EltTy);
231232

232233
if ((NElts * EltSize) % ASTCtx.getCharWidth() != 0) {
233234
// The vector's size in bits is not a multiple of the target's byte size,

clang/lib/AST/ExprConstant.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7474,7 +7474,7 @@ class APValueToBufferConverter {
74747474
QualType EltTy = VTy->getElementType();
74757475
unsigned NElts = VTy->getNumElements();
74767476

7477-
if (VTy->isExtVectorBoolType()) {
7477+
if (VTy->isPackedVectorBoolType(Info.Ctx)) {
74787478
// Special handling for OpenCL bool vectors:
74797479
// Since these vectors are stored as packed bits, but we can't write
74807480
// individual bits to the BitCastBuffer, we'll buffer all of the elements
@@ -7737,11 +7737,11 @@ class BufferToAPValueConverter {
77377737
QualType EltTy = VTy->getElementType();
77387738
unsigned NElts = VTy->getNumElements();
77397739
unsigned EltSize =
7740-
VTy->isExtVectorBoolType() ? 1 : Info.Ctx.getTypeSize(EltTy);
7740+
VTy->isPackedVectorBoolType(Info.Ctx) ? 1 : Info.Ctx.getTypeSize(EltTy);
77417741

77427742
SmallVector<APValue, 4> Elts;
77437743
Elts.reserve(NElts);
7744-
if (VTy->isExtVectorBoolType()) {
7744+
if (VTy->isPackedVectorBoolType(Info.Ctx)) {
77457745
// Special handling for OpenCL bool vectors:
77467746
// Since these vectors are stored as packed bits, but we can't read
77477747
// individual bits from the BitCastBuffer, we'll buffer all of the
@@ -7870,7 +7870,8 @@ static bool checkBitCastConstexprEligibilityType(SourceLocation Loc,
78707870
if (const auto *VTy = Ty->getAs<VectorType>()) {
78717871
QualType EltTy = VTy->getElementType();
78727872
unsigned NElts = VTy->getNumElements();
7873-
unsigned EltSize = VTy->isExtVectorBoolType() ? 1 : Ctx.getTypeSize(EltTy);
7873+
unsigned EltSize =
7874+
VTy->isPackedVectorBoolType(Ctx) ? 1 : Ctx.getTypeSize(EltTy);
78747875

78757876
if ((NElts * EltSize) % Ctx.getCharWidth() != 0) {
78767877
// The vector's size in bits is not a multiple of the target's byte size,

clang/lib/AST/Type.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,12 @@ VectorType::VectorType(TypeClass tc, QualType vecType, unsigned nElements,
410410
VectorTypeBits.NumElements = nElements;
411411
}
412412

413+
bool Type::isPackedVectorBoolType(const ASTContext &ctx) const {
414+
if (ctx.getLangOpts().HLSL)
415+
return false;
416+
return isExtVectorBoolType();
417+
}
418+
413419
BitIntType::BitIntType(bool IsUnsigned, unsigned NumBits)
414420
: Type(BitInt, QualType{}, TypeDependence::None), IsUnsigned(IsUnsigned),
415421
NumBits(NumBits) {}

clang/lib/CodeGen/CGDebugInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3308,7 +3308,7 @@ llvm::DIType *CGDebugInfo::CreateTypeDefinition(const ObjCInterfaceType *Ty,
33083308

33093309
llvm::DIType *CGDebugInfo::CreateType(const VectorType *Ty,
33103310
llvm::DIFile *Unit) {
3311-
if (Ty->isExtVectorBoolType()) {
3311+
if (Ty->isPackedVectorBoolType(CGM.getContext())) {
33123312
// Boolean ext_vector_type(N) are special because their real element type
33133313
// (bits of bit size) is not their Clang element type (_Bool of size byte).
33143314
// For now, we pretend the boolean vector were actually a vector of bytes

clang/lib/CodeGen/CGExpr.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1985,7 +1985,7 @@ llvm::Value *CodeGenFunction::EmitLoadOfScalar(Address Addr, bool Volatile,
19851985

19861986
if (const auto *ClangVecTy = Ty->getAs<VectorType>()) {
19871987
// Boolean vectors use `iN` as storage type.
1988-
if (ClangVecTy->isExtVectorBoolType()) {
1988+
if (ClangVecTy->isPackedVectorBoolType(getContext())) {
19891989
llvm::Type *ValTy = ConvertType(Ty);
19901990
unsigned ValNumElems =
19911991
cast<llvm::FixedVectorType>(ValTy)->getNumElements();
@@ -2064,6 +2064,10 @@ llvm::Value *CodeGenFunction::EmitToMemory(llvm::Value *Value, QualType Ty) {
20642064

20652065
if (Ty->isExtVectorBoolType()) {
20662066
llvm::Type *StoreTy = convertTypeForLoadStore(Ty, Value->getType());
2067+
if (StoreTy->isVectorTy() && StoreTy->getScalarSizeInBits() >
2068+
Value->getType()->getScalarSizeInBits())
2069+
return Builder.CreateZExt(Value, StoreTy);
2070+
20672071
// Expand to the memory bit width.
20682072
unsigned MemNumElems = StoreTy->getPrimitiveSizeInBits();
20692073
// <N x i1> --> <P x i1>.
@@ -2079,8 +2083,9 @@ llvm::Value *CodeGenFunction::EmitToMemory(llvm::Value *Value, QualType Ty) {
20792083
/// by convertTypeForLoadStore) to its primary IR type (as returned
20802084
/// by ConvertType).
20812085
llvm::Value *CodeGenFunction::EmitFromMemory(llvm::Value *Value, QualType Ty) {
2082-
if (Ty->isExtVectorBoolType()) {
2086+
if (Ty->isPackedVectorBoolType(getContext())) {
20832087
const auto *RawIntTy = Value->getType();
2088+
20842089
// Bitcast iP --> <P x i1>.
20852090
auto *PaddedVecTy = llvm::FixedVectorType::get(
20862091
Builder.getInt1Ty(), RawIntTy->getPrimitiveSizeInBits());
@@ -2091,10 +2096,10 @@ llvm::Value *CodeGenFunction::EmitFromMemory(llvm::Value *Value, QualType Ty) {
20912096
return emitBoolVecConversion(V, ValNumElems, "extractvec");
20922097
}
20932098

2094-
if (hasBooleanRepresentation(Ty) || Ty->isBitIntType()) {
2095-
llvm::Type *ResTy = ConvertType(Ty);
2099+
llvm::Type *ResTy = ConvertType(Ty);
2100+
if (hasBooleanRepresentation(Ty) || Ty->isBitIntType() ||
2101+
Ty->isExtVectorBoolType())
20962102
return Builder.CreateTrunc(Value, ResTy, "loadedv");
2097-
}
20982103

20992104
return Value;
21002105
}
@@ -2152,7 +2157,8 @@ void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr,
21522157
if (auto *VecTy = dyn_cast<llvm::FixedVectorType>(SrcTy)) {
21532158
auto *NewVecTy =
21542159
CGM.getABIInfo().getOptimalVectorMemoryType(VecTy, getLangOpts());
2155-
if (!ClangVecTy->isExtVectorBoolType() && VecTy != NewVecTy) {
2160+
if (!ClangVecTy->isPackedVectorBoolType(getContext()) &&
2161+
VecTy != NewVecTy) {
21562162
SmallVector<int, 16> Mask(NewVecTy->getNumElements(), -1);
21572163
std::iota(Mask.begin(), Mask.begin() + VecTy->getNumElements(), 0);
21582164
Value = Builder.CreateShuffleVector(Value, Mask, "extractVec");
@@ -2343,7 +2349,15 @@ RValue CodeGenFunction::EmitLoadOfExtVectorElementLValue(LValue LV) {
23432349
if (!ExprVT) {
23442350
unsigned InIdx = getAccessedFieldNo(0, Elts);
23452351
llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
2346-
return RValue::get(Builder.CreateExtractElement(Vec, Elt));
2352+
2353+
llvm::Value *Element = Builder.CreateExtractElement(Vec, Elt);
2354+
2355+
llvm::Type *LVTy = ConvertType(LV.getType());
2356+
if (Element->getType()->getPrimitiveSizeInBits() >
2357+
LVTy->getPrimitiveSizeInBits())
2358+
Element = Builder.CreateTrunc(Element, LVTy);
2359+
2360+
return RValue::get(Element);
23472361
}
23482362

23492363
// Always use shuffle vector to try to retain the original program structure
@@ -2354,6 +2368,10 @@ RValue CodeGenFunction::EmitLoadOfExtVectorElementLValue(LValue LV) {
23542368
Mask.push_back(getAccessedFieldNo(i, Elts));
23552369

23562370
Vec = Builder.CreateShuffleVector(Vec, Mask);
2371+
2372+
if (LV.getType()->isExtVectorBoolType())
2373+
Vec = Builder.CreateTrunc(Vec, ConvertType(LV.getType()), "truncv");
2374+
23572375
return RValue::get(Vec);
23582376
}
23592377

@@ -2407,26 +2425,35 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
24072425
// Read/modify/write the vector, inserting the new element.
24082426
llvm::Value *Vec = Builder.CreateLoad(Dst.getVectorAddress(),
24092427
Dst.isVolatileQualified());
2428+
llvm::Type *VecTy = Vec->getType();
2429+
llvm::Value *SrcVal = Src.getScalarVal();
2430+
2431+
if (SrcVal->getType()->getPrimitiveSizeInBits() <
2432+
VecTy->getScalarSizeInBits())
2433+
SrcVal = Builder.CreateZExt(SrcVal, VecTy->getScalarType());
2434+
24102435
auto *IRStoreTy = dyn_cast<llvm::IntegerType>(Vec->getType());
24112436
if (IRStoreTy) {
24122437
auto *IRVecTy = llvm::FixedVectorType::get(
24132438
Builder.getInt1Ty(), IRStoreTy->getPrimitiveSizeInBits());
24142439
Vec = Builder.CreateBitCast(Vec, IRVecTy);
24152440
// iN --> <N x i1>.
24162441
}
2417-
llvm::Value *SrcVal = Src.getScalarVal();
2442+
24182443
// Allow inserting `<1 x T>` into an `<N x T>`. It can happen with scalar
24192444
// types which are mapped to vector LLVM IR types (e.g. for implementing
24202445
// an ABI).
24212446
if (auto *EltTy = dyn_cast<llvm::FixedVectorType>(SrcVal->getType());
24222447
EltTy && EltTy->getNumElements() == 1)
24232448
SrcVal = Builder.CreateBitCast(SrcVal, EltTy->getElementType());
2449+
24242450
Vec = Builder.CreateInsertElement(Vec, SrcVal, Dst.getVectorIdx(),
24252451
"vecins");
24262452
if (IRStoreTy) {
24272453
// <N x i1> --> <iN>.
24282454
Vec = Builder.CreateBitCast(Vec, IRStoreTy);
24292455
}
2456+
24302457
Builder.CreateStore(Vec, Dst.getVectorAddress(),
24312458
Dst.isVolatileQualified());
24322459
return;
@@ -2623,14 +2650,12 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
26232650
// This access turns into a read/modify/write of the vector. Load the input
26242651
// value now.
26252652
llvm::Value *Vec = Builder.CreateLoad(DstAddr, Dst.isVolatileQualified());
2653+
llvm::Type *VecTy = Vec->getType();
26262654
const llvm::Constant *Elts = Dst.getExtVectorElts();
26272655

2628-
llvm::Value *SrcVal = Src.getScalarVal();
2629-
26302656
if (const VectorType *VTy = Dst.getType()->getAs<VectorType>()) {
26312657
unsigned NumSrcElts = VTy->getNumElements();
2632-
unsigned NumDstElts =
2633-
cast<llvm::FixedVectorType>(Vec->getType())->getNumElements();
2658+
unsigned NumDstElts = cast<llvm::FixedVectorType>(VecTy)->getNumElements();
26342659
if (NumDstElts == NumSrcElts) {
26352660
// Use shuffle vector is the src and destination are the same number of
26362661
// elements and restore the vector mask since it is on the side it will be
@@ -2639,6 +2664,11 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
26392664
for (unsigned i = 0; i != NumSrcElts; ++i)
26402665
Mask[getAccessedFieldNo(i, Elts)] = i;
26412666

2667+
llvm::Value *SrcVal = Src.getScalarVal();
2668+
if (VecTy->getScalarSizeInBits() >
2669+
SrcVal->getType()->getScalarSizeInBits())
2670+
SrcVal = Builder.CreateZExt(SrcVal, VecTy);
2671+
26422672
Vec = Builder.CreateShuffleVector(SrcVal, Mask);
26432673
} else if (NumDstElts > NumSrcElts) {
26442674
// Extended the source vector to the same length and then shuffle it
@@ -2649,7 +2679,8 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
26492679
for (unsigned i = 0; i != NumSrcElts; ++i)
26502680
ExtMask.push_back(i);
26512681
ExtMask.resize(NumDstElts, -1);
2652-
llvm::Value *ExtSrcVal = Builder.CreateShuffleVector(SrcVal, ExtMask);
2682+
llvm::Value *ExtSrcVal =
2683+
Builder.CreateShuffleVector(Src.getScalarVal(), ExtMask);
26532684
// build identity
26542685
SmallVector<int, 4> Mask;
26552686
for (unsigned i = 0; i != NumDstElts; ++i)
@@ -2674,6 +2705,11 @@ void CodeGenFunction::EmitStoreThroughExtVectorComponentLValue(RValue Src,
26742705
// be updating one element.
26752706
unsigned InIdx = getAccessedFieldNo(0, Elts);
26762707
llvm::Value *Elt = llvm::ConstantInt::get(SizeTy, InIdx);
2708+
2709+
llvm::Value *SrcVal = Src.getScalarVal();
2710+
if (VecTy->getScalarSizeInBits() > SrcVal->getType()->getScalarSizeInBits())
2711+
SrcVal = Builder.CreateZExt(SrcVal, VecTy->getScalarType());
2712+
26772713
Vec = Builder.CreateInsertElement(Vec, SrcVal, Elt);
26782714
}
26792715

@@ -4701,9 +4737,13 @@ EmitExtVectorElementExpr(const ExtVectorElementExpr *E) {
47014737

47024738
// Store the vector to memory (because LValue wants an address).
47034739
Address VecMem = CreateMemTemp(E->getBase()->getType());
4740+
// need to zero extend an hlsl boolean vector to store it back to memory
4741+
QualType Ty = E->getBase()->getType();
4742+
llvm::Type *LTy = convertTypeForLoadStore(Ty, Vec->getType());
4743+
if (LTy->getScalarSizeInBits() > Vec->getType()->getScalarSizeInBits())
4744+
Vec = Builder.CreateZExt(Vec, LTy);
47044745
Builder.CreateStore(Vec, VecMem);
4705-
Base = MakeAddrLValue(VecMem, E->getBase()->getType(),
4706-
AlignmentSource::Decl);
4746+
Base = MakeAddrLValue(VecMem, Ty, AlignmentSource::Decl);
47074747
}
47084748

47094749
QualType type =

clang/lib/CodeGen/CGExprConstant.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1981,7 +1981,10 @@ llvm::Constant *ConstantEmitter::emitForMemory(CodeGenModule &CGM,
19811981
}
19821982

19831983
// Zero-extend bool.
1984-
if (C->getType()->isIntegerTy(1) && !destType->isBitIntType()) {
1984+
// In HLSL bool vectors are stored in memory as a vector of i32
1985+
if ((C->getType()->isIntegerTy(1) && !destType->isBitIntType()) ||
1986+
(destType->isExtVectorBoolType() &&
1987+
!destType->isPackedVectorBoolType(CGM.getContext()))) {
19851988
llvm::Type *boolTy = CGM.getTypes().ConvertTypeForMem(destType);
19861989
llvm::Constant *Res = llvm::ConstantFoldCastOperand(
19871990
llvm::Instruction::ZExt, C, boolTy, CGM.getDataLayout());

clang/lib/CodeGen/CodeGenTypes.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,12 @@ llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
115115
// Check for the boolean vector case.
116116
if (T->isExtVectorBoolType()) {
117117
auto *FixedVT = cast<llvm::FixedVectorType>(R);
118+
119+
if (Context.getLangOpts().HLSL) {
120+
llvm::Type *IRElemTy = ConvertTypeForMem(Context.BoolTy);
121+
return llvm::FixedVectorType::get(IRElemTy, FixedVT->getNumElements());
122+
}
123+
118124
// Pad to at least one byte.
119125
uint64_t BytePadded = std::max<uint64_t>(FixedVT->getNumElements(), 8);
120126
return llvm::IntegerType::get(FixedVT->getContext(), BytePadded);
@@ -657,7 +663,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
657663
case Type::Vector: {
658664
const auto *VT = cast<VectorType>(Ty);
659665
// An ext_vector_type of Bool is really a vector of bits.
660-
llvm::Type *IRElemTy = VT->isExtVectorBoolType()
666+
llvm::Type *IRElemTy = VT->isPackedVectorBoolType(Context)
661667
? llvm::Type::getInt1Ty(getLLVMContext())
662668
: VT->getElementType()->isMFloat8Type()
663669
? llvm::Type::getInt8Ty(getLLVMContext())

clang/lib/Sema/SemaExprMember.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1697,7 +1697,7 @@ static ExprResult LookupMemberExpr(Sema &S, LookupResult &R,
16971697
QualType(), false);
16981698
}
16991699

1700-
if (BaseType->isExtVectorBoolType()) {
1700+
if (BaseType->isPackedVectorBoolType(S.Context)) {
17011701
// We disallow element access for ext_vector_type bool. There is no way to
17021702
// materialize a reference to a vector element as a pointer (each element is
17031703
// one bit in the vector).

0 commit comments

Comments
 (0)