Skip to content

Commit dda80a9

Browse files
authored
[SPIR-V] Emulate OpBitFieldExtract for type != i32 (microsoft#6500)
This is a follow up to microsoft#6491 which adds emulation for OpBitField*Extract This is required because Vulkan requires OpBitField*Extract operands to be 32-bit integers. Fixes microsoft#6327 --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent 1e59ce9 commit dda80a9

File tree

6 files changed

+175
-47
lines changed

6 files changed

+175
-47
lines changed

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -451,11 +451,10 @@ class SpirvBuilder {
451451

452452
/// \brief Creates an OpBitFieldUExtract or OpBitFieldSExtract SPIR-V
453453
/// instruction for the given arguments.
454-
SpirvBitFieldExtract *createBitFieldExtract(QualType resultType,
455-
SpirvInstruction *base,
456-
SpirvInstruction *offset,
457-
SpirvInstruction *count,
458-
bool isSigned, SourceLocation);
454+
SpirvInstruction *createBitFieldExtract(QualType resultType,
455+
SpirvInstruction *base,
456+
unsigned bitOffset, unsigned bitCount,
457+
SourceLocation, SourceRange);
459458

460459
/// \brief Creates an OpEmitVertex instruction.
461460
void createEmitVertex(SourceLocation, SourceRange range = {});
@@ -840,6 +839,12 @@ class SpirvBuilder {
840839
unsigned bitOffset, unsigned bitCount,
841840
SourceLocation, SourceRange);
842841

842+
SpirvInstruction *
843+
createEmulatedBitFieldExtract(QualType resultType, uint32_t baseTypeBitwidth,
844+
SpirvInstruction *base, unsigned bitOffset,
845+
unsigned bitCount, SourceLocation loc,
846+
SourceRange range);
847+
843848
private:
844849
ASTContext &astContext;
845850
SpirvContext &context; ///< From which we allocate various SPIR-V object

tools/clang/include/clang/SPIRV/SpirvInstruction.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1141,7 +1141,7 @@ class SpirvBitFieldExtract : public SpirvBitField {
11411141
public:
11421142
SpirvBitFieldExtract(QualType resultType, SourceLocation loc,
11431143
SpirvInstruction *base, SpirvInstruction *offset,
1144-
SpirvInstruction *count, bool isSigned);
1144+
SpirvInstruction *count);
11451145

11461146
DEFINE_RELEASE_MEMORY_FOR_CLASS(SpirvBitFieldExtract)
11471147

@@ -1151,10 +1151,6 @@ class SpirvBitFieldExtract : public SpirvBitField {
11511151
}
11521152

11531153
bool invokeVisitor(Visitor *v) override;
1154-
1155-
uint32_t isSigned() const {
1156-
return getopcode() == spv::Op::OpBitFieldSExtract;
1157-
}
11581154
};
11591155

11601156
class SpirvBitFieldInsert : public SpirvBitField {

tools/clang/lib/SPIRV/SpirvBuilder.cpp

Lines changed: 67 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -226,17 +226,9 @@ SpirvInstruction *SpirvBuilder::createLoad(QualType resultType,
226226
if (!bitfieldInfo.hasValue())
227227
return instruction;
228228

229-
auto *offset = getConstantInt(
230-
astContext.UnsignedIntTy,
231-
llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->offsetInBits),
232-
/* isSigned= */ false));
233-
auto *count = getConstantInt(
234-
astContext.UnsignedIntTy,
235-
llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->sizeInBits),
236-
/* isSigned= */ false));
237-
return createBitFieldExtract(
238-
resultType, instruction, offset, count,
239-
pointer->getAstResultType()->isSignedIntegerOrEnumerationType(), loc);
229+
return createBitFieldExtract(resultType, instruction,
230+
bitfieldInfo->offsetInBits,
231+
bitfieldInfo->sizeInBits, loc, range);
240232
}
241233

242234
SpirvCopyObject *SpirvBuilder::createCopyObject(QualType resultType,
@@ -980,12 +972,71 @@ SpirvBuilder::createBitFieldInsert(QualType resultType, SpirvInstruction *base,
980972
return inst;
981973
}
982974

983-
SpirvBitFieldExtract *SpirvBuilder::createBitFieldExtract(
984-
QualType resultType, SpirvInstruction *base, SpirvInstruction *offset,
985-
SpirvInstruction *count, bool isSigned, SourceLocation loc) {
975+
SpirvInstruction *SpirvBuilder::createEmulatedBitFieldExtract(
976+
QualType resultType, uint32_t baseTypeBitwidth, SpirvInstruction *base,
977+
unsigned bitOffset, unsigned bitCount, SourceLocation loc,
978+
SourceRange range) {
979+
assert(bitCount <= 64 &&
980+
"Bitfield extraction emulation can only extract at most 64 bits.");
981+
982+
// The base is a raw struct field, which can contain several bitfields:
983+
// raw field: AAAABBBBCCCCCCCCDDDD
984+
// Extracting B means shifting it right until B's LSB is the basetype LSB.
985+
// But first, we need to left shift until B's MSB becomes the basetype MSB:
986+
// - is B is signed, its sign bits won't necessarily extend up to the
987+
// basetype MSB.
988+
// - meaning a right-shift could fail to sign-extend.
989+
// - shifting left first, then right makes sure the sign extension happens.
990+
991+
// input: AAAABBBBCCCCCCCCDDDD
992+
// output: BBBBCCCCCCCCDDDD0000
993+
auto *leftShiftOffset =
994+
getConstantInt(astContext.UnsignedIntTy,
995+
llvm::APInt(32, baseTypeBitwidth - bitOffset - bitCount));
996+
auto *leftShift = createBinaryOp(spv::Op::OpShiftLeftLogical, resultType,
997+
base, leftShiftOffset, loc, range);
998+
999+
// input: BBBBCCCCCCCCDDDD0000
1000+
// output: SSSSSSSSSSSSSSSSBBBB
1001+
auto *rightShiftOffset = getConstantInt(
1002+
astContext.UnsignedIntTy, llvm::APInt(32, baseTypeBitwidth - bitCount));
1003+
auto *rightShift = createBinaryOp(spv::Op::OpShiftRightArithmetic, resultType,
1004+
leftShift, rightShiftOffset, loc, range);
1005+
1006+
if (resultType == QualType({})) {
1007+
auto baseType = dyn_cast<IntegerType>(base->getResultType());
1008+
leftShift->setResultType(baseType);
1009+
rightShift->setResultType(baseType);
1010+
}
1011+
1012+
return rightShift;
1013+
}
1014+
1015+
SpirvInstruction *
1016+
SpirvBuilder::createBitFieldExtract(QualType resultType, SpirvInstruction *base,
1017+
unsigned bitOffset, unsigned bitCount,
1018+
SourceLocation loc, SourceRange range) {
9861019
assert(insertPoint && "null insert point");
987-
auto *inst = new (context)
988-
SpirvBitFieldExtract(resultType, loc, base, offset, count, isSigned);
1020+
1021+
uint32_t bitWidth = 0;
1022+
if (resultType == QualType({})) {
1023+
assert(base->hasResultType() && "No type information for bitfield.");
1024+
bitWidth = dyn_cast<IntegerType>(base->getResultType())->getBitwidth();
1025+
} else {
1026+
bitWidth = getElementSpirvBitwidth(astContext, resultType,
1027+
spirvOptions.enable16BitTypes);
1028+
}
1029+
1030+
if (bitWidth != 32)
1031+
return createEmulatedBitFieldExtract(resultType, bitWidth, base, bitOffset,
1032+
bitCount, loc, range);
1033+
1034+
auto *offset =
1035+
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitOffset));
1036+
auto *count =
1037+
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitCount));
1038+
auto *inst =
1039+
new (context) SpirvBitFieldExtract(resultType, loc, base, offset, count);
9891040
insertPoint->addInstruction(inst);
9901041
inst->setRValue(true);
9911042
return inst;

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8298,14 +8298,9 @@ SpirvInstruction *SpirvEmitter::castToInt(SpirvInstruction *fromVal,
82988298
firstField = spvBuilder.createCompositeExtract(fieldType, fromVal, {0},
82998299
srcLoc, srcRange);
83008300
if (fieldDecl->isBitField()) {
8301-
auto offset = spvBuilder.getConstantInt(astContext.UnsignedIntTy,
8302-
llvm::APInt(32, 0));
8303-
auto width = spvBuilder.getConstantInt(
8304-
astContext.UnsignedIntTy,
8305-
llvm::APInt(32, fieldDecl->getBitWidthValue(astContext)));
83068301
firstField = spvBuilder.createBitFieldExtract(
8307-
fieldType, firstField, offset, width,
8308-
toIntType->hasSignedIntegerRepresentation(), srcLoc);
8302+
fieldType, firstField, 0, fieldDecl->getBitWidthValue(astContext),
8303+
srcLoc, srcRange);
83098304
}
83108305
}
83118306

@@ -9367,10 +9362,7 @@ SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
93679362
llvm::SmallVector<SpirvInstruction *, 4> isRefByteZero;
93689363
for (uint32_t i = 0; i < 4; ++i) {
93699364
refBytes.push_back(spvBuilder.createBitFieldExtract(
9370-
uintType, reference, /*offset*/
9371-
spvBuilder.getConstantInt(astContext.UnsignedIntTy,
9372-
llvm::APInt(32, i * 8)),
9373-
/*count*/ uint8, /*isSigned*/ false, loc));
9365+
uintType, reference, /*offset*/ i * 8, /*count*/ 8, loc, range));
93749366
signedRefBytes.push_back(spvBuilder.createUnaryOp(
93759367
spv::Op::OpBitcast, intType, refBytes.back(), loc));
93769368
isRefByteZero.push_back(spvBuilder.createBinaryOp(
@@ -9381,11 +9373,8 @@ SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
93819373
for (uint32_t byteCount = 0; byteCount < 4; ++byteCount) {
93829374
// 'count' is always 8 because we are extracting 8 bits out of 32.
93839375
auto *srcByte = spvBuilder.createBitFieldExtract(
9384-
uintType, sources[msadNum],
9385-
/*offset*/
9386-
spvBuilder.getConstantInt(astContext.UnsignedIntTy,
9387-
llvm::APInt(32, 8 * byteCount)),
9388-
/*count*/ uint8, /*isSigned*/ false, loc);
9376+
uintType, sources[msadNum], /*offset*/ 8 * byteCount, /*count*/ 8,
9377+
loc, range);
93899378
auto *signedSrcByte =
93909379
spvBuilder.createUnaryOp(spv::Op::OpBitcast, intType, srcByte, loc);
93919380
auto *sub = spvBuilder.createBinaryOp(spv::Op::OpISub, intType,

tools/clang/lib/SPIRV/SpirvInstruction.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -474,12 +474,15 @@ SpirvBitField::SpirvBitField(Kind kind, spv::Op op, QualType resultType,
474474
: SpirvInstruction(kind, op, resultType, loc), base(baseInst),
475475
offset(offsetInst), count(countInst) {}
476476

477-
SpirvBitFieldExtract::SpirvBitFieldExtract(
478-
QualType resultType, SourceLocation loc, SpirvInstruction *baseInst,
479-
SpirvInstruction *offsetInst, SpirvInstruction *countInst, bool isSigned)
477+
SpirvBitFieldExtract::SpirvBitFieldExtract(QualType resultType,
478+
SourceLocation loc,
479+
SpirvInstruction *baseInst,
480+
SpirvInstruction *offsetInst,
481+
SpirvInstruction *countInst)
480482
: SpirvBitField(IK_BitFieldExtract,
481-
isSigned ? spv::Op::OpBitFieldSExtract
482-
: spv::Op::OpBitFieldUExtract,
483+
resultType->isSignedIntegerOrEnumerationType()
484+
? spv::Op::OpBitFieldSExtract
485+
: spv::Op::OpBitFieldUExtract,
483486
resultType, loc, baseInst, offsetInst, countInst) {}
484487

485488
SpirvBitFieldInsert::SpirvBitFieldInsert(QualType resultType,
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
// RUN: %dxc -T cs_6_2 -E main -spirv -fcgl -enable-16bit-types %s | FileCheck %s
2+
3+
struct S1 {
4+
uint64_t f1 : 32;
5+
uint64_t f2 : 1;
6+
};
7+
// CHECK-DAG: %S1 = OpTypeStruct %ulong
8+
9+
struct S2 {
10+
int64_t f1 : 32;
11+
int64_t f2 : 1;
12+
};
13+
// CHECK-DAG: %S2 = OpTypeStruct %long
14+
15+
struct S3 {
16+
uint64_t f1 : 45;
17+
uint64_t f2 : 10;
18+
uint16_t f3 : 7;
19+
uint32_t f4 : 5;
20+
};
21+
// CHECK-DAG: %S3 = OpTypeStruct %ulong %ushort %uint
22+
23+
[numthreads(1, 1, 1)]
24+
void main() {
25+
uint64_t vulong;
26+
uint32_t vuint;
27+
uint16_t vushort;
28+
int64_t vlong;
29+
30+
S1 s1;
31+
vulong = s1.f1;
32+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s1 %int_0
33+
// CHECK: [[raw:%[0-9]+]] = OpLoad %ulong [[ptr]]
34+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[raw]] %uint_32
35+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %ulong [[tmp]] %uint_32
36+
// CHECK: OpStore %vulong [[out]]
37+
vulong = s1.f2;
38+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s1 %int_0
39+
// CHECK: [[raw:%[0-9]+]] = OpLoad %ulong [[ptr]]
40+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[raw]] %uint_31
41+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %ulong [[tmp]] %uint_63
42+
// CHECK: OpStore %vulong [[out]]
43+
44+
S2 s2;
45+
vlong = s2.f1;
46+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_long %s2 %int_0
47+
// CHECK: [[raw:%[0-9]+]] = OpLoad %long [[ptr]]
48+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %long [[raw]] %uint_32
49+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %long [[tmp]] %uint_32
50+
// CHECK: OpStore %vlong [[out]]
51+
vlong = s2.f2;
52+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_long %s2 %int_0
53+
// CHECK: [[raw:%[0-9]+]] = OpLoad %long [[ptr]]
54+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %long [[raw]] %uint_31
55+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %long [[tmp]] %uint_63
56+
// CHECK: OpStore %vlong [[out]]
57+
58+
S3 s3;
59+
vulong = s3.f1;
60+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s3 %int_0
61+
// CHECK: [[raw:%[0-9]+]] = OpLoad %ulong [[ptr]]
62+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[raw]] %uint_19
63+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %ulong [[tmp]] %uint_19
64+
// CHECK: OpStore %vulong [[out]]
65+
vulong = s3.f2;
66+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s3 %int_0
67+
// CHECK: [[raw:%[0-9]+]] = OpLoad %ulong [[ptr]]
68+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[raw]] %uint_9
69+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %ulong [[tmp]] %uint_54
70+
// CHECK: OpStore %vulong [[out]]
71+
72+
vushort = s3.f3;
73+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s3 %int_1
74+
// CHECK: [[raw:%[0-9]+]] = OpLoad %ushort [[ptr]]
75+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[raw]] %uint_9
76+
// CHECK: [[out:%[0-9]+]] = OpShiftRightArithmetic %ushort [[tmp]] %uint_9
77+
// CHECK: OpStore %vushort [[out]]
78+
79+
vuint = s3.f4;
80+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %s3 %int_2
81+
// CHECK: [[raw:%[0-9]+]] = OpLoad %uint [[ptr]]
82+
// CHECK: [[tmp:%[0-9]+]] = OpBitFieldUExtract %uint [[raw]] %uint_0 %uint_5
83+
// CHECK: OpStore %vuint [[tmp]]
84+
}

0 commit comments

Comments
 (0)