Skip to content

Commit 8497f1d

Browse files
authored
[SPIRV] Handle bit fields with initializer lists (microsoft#6301)
The spir-v backend currently treats each bitfield as a separate member in the spir-v type when processing an initializer list. This commit changes that so that it matches the behaviour of the DXIL backend. In particular, if `{bf1, bf2}` is used to initialize a struct with two bitfield, then `bf1` and bf2` will be merged into a single field when constructing the struct. Fixes microsoft#5688
1 parent 557759b commit 8497f1d

File tree

7 files changed

+129
-92
lines changed

7 files changed

+129
-92
lines changed

tools/clang/include/clang/SPIRV/AstTypeProbe.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "dxc/Support/SPIRVOptions.h"
1414
#include "clang/AST/Decl.h"
1515
#include "clang/AST/Type.h"
16+
#include "clang/SPIRV/SpirvType.h"
1617
#include "clang/Sema/Sema.h"
1718

1819
namespace clang {
@@ -383,6 +384,28 @@ bool isStructureContainingAnyKindOfBuffer(QualType type);
383384
/// types, or it's an array of scalar, vector, or matrix of numeric types.
384385
bool isScalarOrNonStructAggregateOfNumericalTypes(QualType type);
385386

387+
/// Calls `operation` on for each field in the base and derives class defined by
388+
/// `recordType`. The `operation` will receive the AST type linked to the field,
389+
/// the SPIRV type linked to the field, and the index of the field in the final
390+
/// SPIR-V representation. This index of the field can vary from the AST
391+
/// field-index because bitfields are merged into a single field in the SPIR-V
392+
/// representation.
393+
///
394+
/// If `includeMerged` is true, `operation` will be called on the same spir-v
395+
/// field for each field it represents. For example, if a spir-v field holds the
396+
/// values for 3 bit-fields, `operation` will be called 3 times with the same
397+
/// `spirvFieldIndex`. The `bitfield` information in `field` will be different.
398+
///
399+
/// If false, `operation` will be called once on the first field in the merged
400+
/// field.
401+
///
402+
/// If the operation returns false, we stop processing fields.
403+
void forEachSpirvField(
404+
const RecordType *recordType, const StructType *spirvType,
405+
std::function<bool(size_t spirvFieldIndex, const QualType &fieldType,
406+
const StructType::FieldInfo &field)>
407+
operation,
408+
bool includeMerged = false);
386409
} // namespace spirv
387410
} // namespace clang
388411

tools/clang/lib/SPIRV/AstTypeProbe.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1548,5 +1548,46 @@ bool isScalarOrNonStructAggregateOfNumericalTypes(QualType type) {
15481548
return false;
15491549
}
15501550

1551+
void forEachSpirvField(
1552+
const RecordType *recordType, const StructType *spirvType,
1553+
std::function<bool(size_t, const QualType &, const StructType::FieldInfo &)>
1554+
operation,
1555+
bool includeMerged) {
1556+
const auto *cxxDecl = recordType->getAsCXXRecordDecl();
1557+
const auto *recordDecl = recordType->getDecl();
1558+
1559+
// Iterate through the base class (one field per base class).
1560+
// Bases cannot be melded into 1 field like bitfields, simple iteration.
1561+
uint32_t lastConvertedIndex = 0;
1562+
size_t astFieldIndex = 0;
1563+
for (const auto &base : cxxDecl->bases()) {
1564+
const auto &type = base.getType();
1565+
const auto &spirvField = spirvType->getFields()[astFieldIndex];
1566+
if (!operation(spirvField.fieldIndex, type, spirvField)) {
1567+
return;
1568+
}
1569+
lastConvertedIndex = spirvField.fieldIndex;
1570+
++astFieldIndex;
1571+
}
1572+
1573+
// Iterate through the derived class fields. Field could be merged.
1574+
for (const auto *field : recordDecl->fields()) {
1575+
const auto &spirvField = spirvType->getFields()[astFieldIndex];
1576+
const uint32_t currentFieldIndex = spirvField.fieldIndex;
1577+
if (!includeMerged && astFieldIndex > 0 &&
1578+
currentFieldIndex == lastConvertedIndex) {
1579+
++astFieldIndex;
1580+
continue;
1581+
}
1582+
1583+
const auto &type = field->getType();
1584+
if (!operation(currentFieldIndex, type, spirvField)) {
1585+
return;
1586+
}
1587+
lastConvertedIndex = currentFieldIndex;
1588+
++astFieldIndex;
1589+
}
1590+
}
1591+
15511592
} // namespace spirv
15521593
} // namespace clang

tools/clang/lib/SPIRV/InitListHandler.cpp

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@
1717
#include <algorithm>
1818
#include <iterator>
1919

20+
#include "LowerTypeVisitor.h"
2021
#include "llvm/ADT/SmallVector.h"
2122

2223
namespace clang {
2324
namespace spirv {
2425

25-
InitListHandler::InitListHandler(const ASTContext &ctx, SpirvEmitter &emitter)
26+
InitListHandler::InitListHandler(ASTContext &ctx, SpirvEmitter &emitter)
2627
: astContext(ctx), theEmitter(emitter),
2728
spvBuilder(emitter.getSpirvBuilder()),
2829
diags(emitter.getDiagnosticsEngine()) {}
@@ -399,39 +400,54 @@ InitListHandler::createInitForStructType(QualType type, SourceLocation srcLoc,
399400
tryToSplitStruct();
400401
}
401402

402-
llvm::SmallVector<SpirvInstruction *, 4> fields;
403-
404-
// Initialize base classes first.
405-
llvm::SmallVector<SpirvInstruction *, 4> base_fields;
406-
const RecordDecl *structDecl = type->getAsStructureType()->getDecl();
407-
if (auto *cxxStructDecl = dyn_cast<CXXRecordDecl>(structDecl)) {
408-
for (CXXBaseSpecifier base : cxxStructDecl->bases()) {
409-
QualType baseType = base.getType();
410-
const RecordType *baseStructType = baseType->getAsStructureType();
411-
if (baseStructType == nullptr) {
412-
continue;
413-
}
414-
const RecordDecl *baseStructDecl = baseStructType->getDecl();
415-
for (const auto *field : baseStructDecl->fields()) {
416-
base_fields.push_back(
417-
createInitForType(field->getType(), field->getLocation(), range));
418-
if (!base_fields.back())
419-
return nullptr;
420-
}
421-
fields.push_back(spvBuilder.createCompositeConstruct(
422-
baseType, base_fields, srcLoc, range));
423-
base_fields.clear();
424-
}
425-
}
403+
const RecordType *recordType = type->getAs<RecordType>();
404+
assert(recordType);
426405

427-
for (const auto *field : structDecl->fields()) {
428-
fields.push_back(
429-
createInitForType(field->getType(), field->getLocation(), range));
430-
if (!fields.back())
431-
return nullptr;
432-
}
406+
LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
407+
theEmitter.getSpirvOptions());
408+
const SpirvType *spirvType =
409+
lowerTypeVisitor.lowerType(type, SpirvLayoutRule::Void, false, srcLoc);
433410

434-
// TODO: use OpConstantComposite when all components are constants
411+
llvm::SmallVector<SpirvInstruction *, 4> fields;
412+
const StructType *structType = dyn_cast<StructType>(spirvType);
413+
assert(structType != nullptr);
414+
forEachSpirvField(
415+
recordType, structType,
416+
[this, &fields, srcLoc, range](size_t spirvFieldIndex,
417+
const QualType &fieldType,
418+
const StructType::FieldInfo &fieldInfo) {
419+
SpirvInstruction *init = createInitForType(fieldType, srcLoc, range);
420+
421+
// For non bit-fields, `init` will be the value for the component.
422+
if (!fieldInfo.bitfield.hasValue()) {
423+
assert(fields.size() == fieldInfo.fieldIndex);
424+
fields.push_back(init);
425+
return true;
426+
}
427+
428+
// For a bit fields we need to insert it into the container.
429+
// The first time we see this bit field, init is used as the value.
430+
// This assumes that 0 is the first offset in the bitfield.
431+
if (fields.size() <= fieldInfo.fieldIndex) {
432+
assert(fieldInfo.bitfield->offsetInBits == 0);
433+
fields.push_back(init);
434+
return true;
435+
}
436+
437+
// For the remaining bitfields, we need to insert them into the existing
438+
// container, which is the last element in `fields`.
439+
assert(fields.size() == fieldInfo.fieldIndex + 1);
440+
SpirvInstruction *offset = spvBuilder.getConstantInt(
441+
astContext.UnsignedIntTy,
442+
llvm::APInt(32, fieldInfo.bitfield->offsetInBits));
443+
SpirvInstruction *count = spvBuilder.getConstantInt(
444+
astContext.UnsignedIntTy,
445+
llvm::APInt(32, fieldInfo.bitfield->sizeInBits));
446+
fields.back() = spvBuilder.createBitFieldInsert(
447+
fieldType, fields.back(), init, offset, count, srcLoc);
448+
return true;
449+
},
450+
true);
435451
return spvBuilder.createCompositeConstruct(type, fields, srcLoc, range);
436452
}
437453

tools/clang/lib/SPIRV/InitListHandler.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class InitListHandler {
8080
/// Constructs an InitListHandler which uses the given emitter for normal
8181
/// translation tasks. It will reuse the ModuleBuilder embedded in the given
8282
/// emitter.
83-
explicit InitListHandler(const ASTContext &, SpirvEmitter &);
83+
InitListHandler(ASTContext &ctx, SpirvEmitter &emitter);
8484

8585
/// Processes the given InitListExpr and returns the <result-id> for the final
8686
/// SPIR-V value.
@@ -143,7 +143,7 @@ class InitListHandler {
143143
SourceLocation);
144144

145145
private:
146-
const ASTContext &astContext;
146+
ASTContext &astContext;
147147
SpirvEmitter &theEmitter;
148148
SpirvBuilder &spvBuilder;
149149
DiagnosticsEngine &diags;

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -564,64 +564,6 @@ const StructType *lowerStructType(const SpirvCodeGenOptions &spirvOptions,
564564
return output;
565565
}
566566

567-
// Calls `operation` on for each field in the base and derives class defined by
568-
// `recordType`. The `operation` will receive the AST type linked to the field,
569-
// the SPIRV type linked to the field, and the index of the field in the final
570-
// SPIR-V representation. This index of the field can vary from the AST
571-
// field-index because bitfields are merged into a single field in the SPIR-V
572-
// representation.
573-
//
574-
// If `includeMerged` is true, `operation` will be called on the same spir-v
575-
// field for each field it represents. For example, if a spir-v field holds the
576-
// values for 3 bit-fields, `operation` will be called 3 times with the same
577-
// `spirvFieldIndex`. The `bitfield` information in `field` will be different.
578-
//
579-
// If false, `operation` will be called once on the first field in the merged
580-
// field.
581-
//
582-
// If the operation returns false, we stop processing fields.
583-
void forEachSpirvField(
584-
const RecordType *recordType, const StructType *spirvType,
585-
std::function<bool(size_t spirvFieldIndex, const QualType &fieldType,
586-
const StructType::FieldInfo &field)>
587-
operation,
588-
bool includeMerged = false) {
589-
const auto *cxxDecl = recordType->getAsCXXRecordDecl();
590-
const auto *recordDecl = recordType->getDecl();
591-
592-
// Iterate through the base class (one field per base class).
593-
// Bases cannot be melded into 1 field like bitfields, simple iteration.
594-
uint32_t lastConvertedIndex = 0;
595-
size_t astFieldIndex = 0;
596-
for (const auto &base : cxxDecl->bases()) {
597-
const auto &type = base.getType();
598-
const auto &spirvField = spirvType->getFields()[astFieldIndex];
599-
if (!operation(spirvField.fieldIndex, type, spirvField)) {
600-
return;
601-
}
602-
lastConvertedIndex = spirvField.fieldIndex;
603-
++astFieldIndex;
604-
}
605-
606-
// Iterate through the derived class fields. Field could be merged.
607-
for (const auto *field : recordDecl->fields()) {
608-
const auto &spirvField = spirvType->getFields()[astFieldIndex];
609-
const uint32_t currentFieldIndex = spirvField.fieldIndex;
610-
if (!includeMerged && astFieldIndex > 0 &&
611-
currentFieldIndex == lastConvertedIndex) {
612-
++astFieldIndex;
613-
continue;
614-
}
615-
616-
const auto &type = field->getType();
617-
if (!operation(currentFieldIndex, type, spirvField)) {
618-
return;
619-
}
620-
lastConvertedIndex = currentFieldIndex;
621-
++astFieldIndex;
622-
}
623-
}
624-
625567
} // namespace
626568

627569
SpirvEmitter::SpirvEmitter(CompilerInstance &ci)

tools/clang/lib/SPIRV/SpirvEmitter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ class SpirvEmitter : public ASTConsumer {
6363

6464
ASTContext &getASTContext() { return astContext; }
6565
SpirvBuilder &getSpirvBuilder() { return spvBuilder; }
66+
SpirvContext &getSpirvContext() { return spvContext; }
6667
DiagnosticsEngine &getDiagnosticsEngine() { return diags; }
6768
CompilerInstance &getCompilerInstance() { return theCompilerInstance; }
6869
SpirvCodeGenOptions &getSpirvOptions() { return spirvOptions; }

tools/clang/test/CodeGenSPIRV/var.init.struct.hlsl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ struct W {
4141
float4 color;
4242
};
4343

44+
struct BitFields {
45+
uint R:8;
46+
uint G:8;
47+
uint B:8;
48+
uint A:8;
49+
};
50+
4451
void main() {
4552
// CHECK-LABEL: %bb_entry = OpLabel
4653

@@ -103,4 +110,11 @@ void main() {
103110
// CHECK-NEXT: [[float4_zero:%[0-9]+]] = OpConvertSToF %v4float [[int4_zero]]
104111
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeConstruct %W [[float4_zero]]
105112
W w = { (0).xxxx };
113+
114+
// CHECK: [[v1:%[0-9]+]] = OpBitFieldInsert %uint %uint_3 %uint_2 %uint_8 %uint_8
115+
// CHECK: [[v2:%[0-9]+]] = OpBitFieldInsert %uint [[v1]] %uint_1 %uint_16 %uint_8
116+
// CHECK: [[v3:%[0-9]+]] = OpBitFieldInsert %uint [[v2]] %uint_0 %uint_24 %uint_8
117+
// CHECK: [[bf:%[0-9]+]] = OpCompositeConstruct %BitFields [[v3]]
118+
// CHECK: OpStore %bf [[bf]]
119+
BitFields bf = {3, 2, 1, 0};
106120
}

0 commit comments

Comments
 (0)