Skip to content

Commit 0781ded

Browse files
authored
[SPIR-V] Emulate OpBitFieldInsert for type != i32 (microsoft#6491)
SPIR-V supported OpBitFieldInsert on all integer types. But Vulkan requires the operands to be 32bit integers. This means we need to emulate this instruction for types other then i32. This PR only adds emulation for OpBitFieldInsert. The next PR will add support for emulating OpBitFieldExtract, Related to microsoft#6327 --------- Signed-off-by: Nathan Gauër <[email protected]>
1 parent 8b5b6c6 commit 0781ded

File tree

5 files changed

+275
-38
lines changed

5 files changed

+275
-38
lines changed

tools/clang/include/clang/SPIRV/SpirvBuilder.h

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -443,10 +443,11 @@ class SpirvBuilder {
443443

444444
/// \brief Creates an OpBitFieldInsert SPIR-V instruction for the given
445445
/// arguments.
446-
SpirvBitFieldInsert *
447-
createBitFieldInsert(QualType resultType, SpirvInstruction *base,
448-
SpirvInstruction *insert, SpirvInstruction *offset,
449-
SpirvInstruction *count, SourceLocation);
446+
SpirvInstruction *createBitFieldInsert(QualType resultType,
447+
SpirvInstruction *base,
448+
SpirvInstruction *insert,
449+
unsigned bitOffset, unsigned bitCount,
450+
SourceLocation, SourceRange);
450451

451452
/// \brief Creates an OpBitFieldUExtract or OpBitFieldSExtract SPIR-V
452453
/// instruction for the given arguments.
@@ -831,6 +832,14 @@ class SpirvBuilder {
831832
const SpirvType *spvType,
832833
SpirvInstruction *var);
833834

835+
/// \brief Emulates OpBitFieldInsert SPIR-V instruction for the given
836+
/// arguments.
837+
SpirvInstruction *
838+
createEmulatedBitFieldInsert(QualType resultType, uint32_t baseTypeBitwidth,
839+
SpirvInstruction *base, SpirvInstruction *insert,
840+
unsigned bitOffset, unsigned bitCount,
841+
SourceLocation, SourceRange);
842+
834843
private:
835844
ASTContext &astContext;
836845
SpirvContext &context; ///< From which we allocate various SPIR-V object

tools/clang/lib/SPIRV/InitListHandler.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,9 @@ InitListHandler::createInitForStructType(QualType type, SourceLocation srcLoc,
452452
// For the remaining bitfields, we need to insert them into the existing
453453
// container, which is the last element in `fields`.
454454
assert(fields.size() == fieldInfo.fieldIndex + 1);
455-
SpirvInstruction *offset = spvBuilder.getConstantInt(
456-
astContext.UnsignedIntTy,
457-
llvm::APInt(32, fieldInfo.bitfield->offsetInBits));
458-
SpirvInstruction *count = spvBuilder.getConstantInt(
459-
astContext.UnsignedIntTy,
460-
llvm::APInt(32, fieldInfo.bitfield->sizeInBits));
461455
fields.back() = spvBuilder.createBitFieldInsert(
462-
fieldType, fields.back(), init, offset, count, srcLoc);
456+
fieldType, fields.back(), init, fieldInfo.bitfield->offsetInBits,
457+
fieldInfo.bitfield->sizeInBits, srcLoc, range);
463458
return true;
464459
},
465460
true);

tools/clang/lib/SPIRV/SpirvBuilder.cpp

Lines changed: 102 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -298,16 +298,9 @@ SpirvStore *SpirvBuilder::createStore(SpirvInstruction *address,
298298
context.addToInstructionsWithLoweredType(value);
299299

300300
auto *base = createLoad(value->getResultType(), address, loc, range);
301-
auto *offset = getConstantInt(
302-
astContext.UnsignedIntTy,
303-
llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->offsetInBits),
304-
false));
305-
auto *count = getConstantInt(
306-
astContext.UnsignedIntTy,
307-
llvm::APInt(32, static_cast<uint64_t>(bitfieldInfo->sizeInBits),
308-
false));
309-
source =
310-
createBitFieldInsert(/*QualType*/ {}, base, value, offset, count, loc);
301+
source = createBitFieldInsert(/*QualType*/ {}, base, value,
302+
bitfieldInfo->offsetInBits,
303+
bitfieldInfo->sizeInBits, loc, range);
311304
source->setResultType(value->getResultType());
312305
}
313306

@@ -882,12 +875,106 @@ void SpirvBuilder::createBarrier(spv::Scope memoryScope,
882875
insertPoint->addInstruction(barrier);
883876
}
884877

885-
SpirvBitFieldInsert *SpirvBuilder::createBitFieldInsert(
886-
QualType resultType, SpirvInstruction *base, SpirvInstruction *insert,
887-
SpirvInstruction *offset, SpirvInstruction *count, SourceLocation loc) {
878+
SpirvInstruction *SpirvBuilder::createEmulatedBitFieldInsert(
879+
QualType resultType, uint32_t baseTypeBitwidth, SpirvInstruction *base,
880+
SpirvInstruction *insert, unsigned bitOffset, unsigned bitCount,
881+
SourceLocation loc, SourceRange range) {
882+
883+
// The destination is a raw struct field, which can contain several bitfields:
884+
// raw field: AAAABBBBCCCCCCCCDDDD
885+
// To insert a new value for the field BBBB, we need to clear the B bits in
886+
// the field, and insert the new values.
887+
888+
// Create a mask to clear B from the raw field.
889+
// mask = (1 << bitCount) - 1
890+
// raw field: AAAABBBBCCCCCCCCDDDD
891+
// mask: 00000000000000001111
892+
// cast mask to the an unsigned with the same bitwidth.
893+
// mask = (unsigned dstType)mask
894+
// Move the mask to B's position in the raw type.
895+
// mask = mask << bitOffset
896+
// raw field: AAAABBBBCCCCCCCCDDDD
897+
// mask: 00001111000000000000
898+
// Generate inverted mask to clear other bits in *insert*.
899+
// notMask = ~mask
900+
// raw field: AAAABBBBCCCCCCCCDDDD
901+
// mask: 11110000111111111111
902+
assert(bitCount <= 64 &&
903+
"Bitfield insertion emulation can only insert at most 64 bits.");
904+
auto maskTy =
905+
astContext.getIntTypeForBitwidth(baseTypeBitwidth, /* signed= */ 0);
906+
const uint64_t maskValue = ((1ull << bitCount) - 1ull) << bitOffset;
907+
const uint64_t notMaskValue = ~maskValue;
908+
909+
auto *mask = getConstantInt(maskTy, llvm::APInt(baseTypeBitwidth, maskValue));
910+
auto *notMask =
911+
getConstantInt(maskTy, llvm::APInt(baseTypeBitwidth, notMaskValue));
912+
auto *shiftOffset =
913+
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitOffset));
914+
915+
// base = base & MASK // Clear bits at B's position.
916+
// input: AAAABBBBCCCCCCCCDDDD
917+
// output: AAAA----CCCCCCCCDDDD
918+
auto *clearedDst = createBinaryOp(spv::Op::OpBitwiseAnd, resultType, base,
919+
notMask, loc, range);
920+
921+
// input: SSSSSSSSSSSSSSSSBBBB
922+
// tmp = (dstType)SRC // Convert SRC to the base type.
923+
// tmp = tmp << bitOffset // Move the SRC value to the correct bit offset.
924+
// output: SSSSBBBB------------
925+
// tmp = tmp & ~MASK // Clear any sign extension bits.
926+
// output: ----BBBB------------
927+
auto *castedSrc =
928+
createUnaryOp(spv::Op::OpBitcast, resultType, insert, loc, range);
929+
auto *shiftedSrc = createBinaryOp(spv::Op::OpShiftLeftLogical, resultType,
930+
castedSrc, shiftOffset, loc, range);
931+
auto *maskedSrc = createBinaryOp(spv::Op::OpBitwiseAnd, resultType,
932+
shiftedSrc, mask, loc, range);
933+
934+
// base = base | tmp; // Insert B in the raw field.
935+
// tmp: ----BBBB------------
936+
// base: AAAA----CCCCCCCCDDDD
937+
// output: AAAABBBBCCCCCCCCDDDD
938+
auto *result = createBinaryOp(spv::Op::OpBitwiseOr, resultType, clearedDst,
939+
maskedSrc, loc, range);
940+
941+
if (base->getResultType()) {
942+
auto *dstTy = dyn_cast<IntegerType>(base->getResultType());
943+
clearedDst->setResultType(dstTy);
944+
shiftedSrc->setResultType(dstTy);
945+
maskedSrc->setResultType(dstTy);
946+
castedSrc->setResultType(dstTy);
947+
result->setResultType(dstTy);
948+
}
949+
return result;
950+
}
951+
952+
SpirvInstruction *
953+
SpirvBuilder::createBitFieldInsert(QualType resultType, SpirvInstruction *base,
954+
SpirvInstruction *insert, unsigned bitOffset,
955+
unsigned bitCount, SourceLocation loc,
956+
SourceRange range) {
888957
assert(insertPoint && "null insert point");
889-
auto *inst = new (context)
890-
SpirvBitFieldInsert(resultType, loc, base, insert, offset, count);
958+
959+
uint32_t bitwidth = 0;
960+
if (resultType == QualType({})) {
961+
assert(base->hasResultType() && "No type information for bitfield.");
962+
bitwidth = dyn_cast<IntegerType>(base->getResultType())->getBitwidth();
963+
} else {
964+
bitwidth = getElementSpirvBitwidth(astContext, resultType,
965+
spirvOptions.enable16BitTypes);
966+
}
967+
968+
if (bitwidth != 32)
969+
return createEmulatedBitFieldInsert(resultType, bitwidth, base, insert,
970+
bitOffset, bitCount, loc, range);
971+
972+
auto *insertOffset =
973+
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitOffset));
974+
auto *insertCount =
975+
getConstantInt(astContext.UnsignedIntTy, llvm::APInt(32, bitCount));
976+
auto *inst = new (context) SpirvBitFieldInsert(resultType, loc, base, insert,
977+
insertOffset, insertCount);
891978
insertPoint->addInstruction(inst);
892979
inst->setRValue(true);
893980
return inst;

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9237,6 +9237,7 @@ SpirvEmitter::processIntrinsicNonUniformResourceIndex(const CallExpr *expr) {
92379237
SpirvInstruction *
92389238
SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
92399239
const auto loc = callExpr->getExprLoc();
9240+
const auto range = callExpr->getSourceRange();
92409241
if (!spirvOptions.noWarnEmulatedFeatures)
92419242
emitWarning("msad4 intrinsic function is emulated using many SPIR-V "
92429243
"instructions due to lack of direct SPIR-V equivalent",
@@ -9297,18 +9298,15 @@ SpirvEmitter::processIntrinsicMsad4(const CallExpr *callExpr) {
92979298
// Do bfi 3 times. DXIL bfi is equivalent to SPIR-V OpBitFieldInsert.
92989299
auto *v1y = spvBuilder.createCompositeExtract(uintType, source, {1}, loc);
92999300
// Note that t0.x = v1.x, nothing we need to do for that.
9300-
auto *t0y =
9301-
spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS8, /*insert*/ v1y,
9302-
/*offset*/ uint24,
9303-
/*width*/ uint8, loc);
9304-
auto *t0z =
9305-
spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS16, /*insert*/ v1y,
9306-
/*offset*/ uint16,
9307-
/*width*/ uint16, loc);
9308-
auto *t0w =
9309-
spvBuilder.createBitFieldInsert(uintType, /*base*/ v1xS24, /*insert*/ v1y,
9310-
/*offset*/ uint8,
9311-
/*width*/ uint24, loc);
9301+
auto *t0y = spvBuilder.createBitFieldInsert(
9302+
uintType, /*base*/ v1xS8, /*insert*/ v1y,
9303+
/* bitOffest */ 24, /* bitCount */ 8, loc, range);
9304+
auto *t0z = spvBuilder.createBitFieldInsert(
9305+
uintType, /*base*/ v1xS16, /*insert*/ v1y,
9306+
/* bitOffest */ 16, /* bitCount */ 16, loc, range);
9307+
auto *t0w = spvBuilder.createBitFieldInsert(
9308+
uintType, /*base*/ v1xS24, /*insert*/ v1y,
9309+
/* bitOffest */ 8, /* bitCount */ 24, loc, range);
93129310

93139311
// Step 3. MSAD (Masked Sum of Absolute Differences)
93149312

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// RUN: %dxc -T cs_6_2 -E main -spirv -fcgl -enable-16bit-types %s | FileCheck %s
2+
3+
struct S1 {
4+
uint64_t f1 : 32;
5+
uint64_t f2 : 1;
6+
};
7+
// CHECK-DAG: %S1 = OpTypeStruct %ulong
8+
9+
struct S2 {
10+
uint16_t f1 : 4;
11+
uint16_t f2 : 5;
12+
};
13+
// CHECK-DAG: %S2 = OpTypeStruct %ushort
14+
15+
struct S3 {
16+
uint64_t f1 : 45;
17+
uint64_t f2 : 10;
18+
uint16_t f3 : 7;
19+
uint32_t f4 : 5;
20+
};
21+
// CHECK-DAG: %S3 = OpTypeStruct %ulong %ushort %uint
22+
23+
struct S4 {
24+
int64_t f1 : 32;
25+
int64_t f2 : 1;
26+
};
27+
// CHECK-DAG: %S4 = OpTypeStruct %long
28+
29+
[numthreads(1, 1, 1)]
30+
void main() {
31+
S1 s1;
32+
s1.f1 = 3;
33+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s1 %int_0
34+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
35+
// 0xffffffff00000000
36+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18446744069414584320
37+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_3
38+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_0
39+
// 0x00000000ffffffff
40+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_4294967295
41+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
42+
// CHECK: OpStore [[ptr]] [[mix]]
43+
44+
s1.f2 = 1;
45+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s1 %int_0
46+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
47+
// 0xfffffffeffffffff
48+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18446744069414584319
49+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_1
50+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_32
51+
// 0x0000000100000000
52+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_4294967296
53+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
54+
// CHECK: OpStore [[ptr]] [[mix]]
55+
56+
S2 s2;
57+
s2.f1 = 2;
58+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s2 %int_0
59+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort [[ptr]]
60+
// 0xfff0
61+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_65520
62+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ushort %ushort_2
63+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[val]] %uint_0
64+
// 0x000f
65+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_15
66+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ushort [[dst]] [[src]]
67+
// CHECK: OpStore [[ptr]] [[mix]]
68+
69+
s2.f2 = 3;
70+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s2 %int_0
71+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort [[ptr]]
72+
// 0xfe0f
73+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_65039
74+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ushort %ushort_3
75+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[val]] %uint_4
76+
// 0x01f0
77+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_496
78+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ushort [[dst]] [[src]]
79+
// CHECK: OpStore [[ptr]] [[mix]]
80+
81+
S3 s3;
82+
s3.f1 = 5;
83+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s3 %int_0
84+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
85+
// 0xffffe00000000000
86+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18446708889337462784
87+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_5
88+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_0
89+
// 0x00001fffffffffff
90+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_35184372088831
91+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
92+
// CHECK: OpStore [[ptr]] [[mix]]
93+
94+
s3.f2 = 6;
95+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ulong %s3 %int_0
96+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ulong [[ptr]]
97+
// 0xff801fffffffffff
98+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_18410750461062676479
99+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ulong %ulong_6
100+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ulong [[val]] %uint_45
101+
// 0x007fe00000000000
102+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ulong [[tmp]] %ulong_35993612646875136
103+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ulong [[dst]] [[src]]
104+
// CHECK: OpStore [[ptr]] [[mix]]
105+
106+
s3.f3 = 7;
107+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_ushort %s3 %int_1
108+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %ushort [[ptr]]
109+
// 0xff80
110+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_65408
111+
// CHECK: [[val:%[0-9]+]] = OpBitcast %ushort %ushort_7
112+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %ushort [[val]] %uint_0
113+
// 0x007f
114+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %ushort [[tmp]] %ushort_127
115+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %ushort [[dst]] [[src]]
116+
// CHECK: OpStore [[ptr]] [[mix]]
117+
118+
s3.f4 = 8;
119+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %s3 %int_2
120+
// CHECK: [[val:%[0-9]+]] = OpLoad %uint [[ptr]]
121+
// CHECK: [[tmp:%[0-9]+]] = OpBitFieldInsert %uint [[val]] %uint_8 %uint_0 %uint_5
122+
// CHECK: OpStore [[ptr]] [[tmp]]
123+
124+
S4 s4;
125+
s4.f1 = 3;
126+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_long %s4 %int_0
127+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %long [[ptr]]
128+
// 0xffffffff00000000
129+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_18446744069414584320
130+
// CHECK: [[val:%[0-9]+]] = OpBitcast %long %long_3
131+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %long [[val]] %uint_0
132+
// 0x00000000ffffffff
133+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_4294967295
134+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %long [[dst]] [[src]]
135+
// CHECK: OpStore [[ptr]] [[mix]]
136+
137+
s4.f2 = 1;
138+
// CHECK: [[ptr:%[0-9]+]] = OpAccessChain %_ptr_Function_long %s4 %int_0
139+
// CHECK: [[tmp:%[0-9]+]] = OpLoad %long [[ptr]]
140+
// 0xfffffffeffffffff
141+
// CHECK: [[dst:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_18446744069414584319
142+
// CHECK: [[val:%[0-9]+]] = OpBitcast %long %long_1
143+
// CHECK: [[tmp:%[0-9]+]] = OpShiftLeftLogical %long [[val]] %uint_32
144+
// 0x0000000100000000
145+
// CHECK: [[src:%[0-9]+]] = OpBitwiseAnd %long [[tmp]] %ulong_4294967296
146+
// CHECK: [[mix:%[0-9]+]] = OpBitwiseOr %long [[dst]] [[src]]
147+
// CHECK: OpStore [[ptr]] [[mix]]
148+
}

0 commit comments

Comments
 (0)