Skip to content

Commit c01b4f4

Browse files
authored
Iterate over the spir-v fields to handle bitfields (microsoft#6746)
The code that implements `RWByteAddressBuffer::Store` will iterate over all of the fields in a struct to write each element in the struct. However, it does not use the "Spir-V fields", which accounts for multiple fields being packed into the same bitfield. This is fixed by using the `forEachSpirvField` function to make sure that the bitfield are correctly handled. Fixes microsoft#6483
1 parent 74ba845 commit c01b4f4

File tree

3 files changed

+98
-63
lines changed

3 files changed

+98
-63
lines changed

tools/clang/lib/SPIRV/RawBufferMethods.cpp

Lines changed: 71 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "RawBufferMethods.h"
1010
#include "AlignmentSizeCalculator.h"
11+
#include "LowerTypeVisitor.h"
1112
#include "clang/AST/ASTContext.h"
1213
#include "clang/AST/CharUnits.h"
1314
#include "clang/AST/RecordLayout.h"
@@ -284,44 +285,48 @@ SpirvInstruction *RawBufferHandler::processTemplatedLoadFromBuffer(
284285
// aligned like their field with the largest alignment.
285286
// As a result, there might exist some padding after some struct members.
286287
if (const auto *structType = targetType->getAs<RecordType>()) {
287-
const auto *decl = structType->getDecl();
288+
LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
289+
theEmitter.getSpirvOptions(), spvBuilder);
290+
auto *decl = targetType->getAsTagDecl();
291+
assert(decl && "Expected all structs to be tag decls.");
292+
const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType(
293+
targetType, theEmitter.getSpirvOptions().sBufferLayoutRule, llvm::None,
294+
decl->getLocation()));
288295
llvm::SmallVector<SpirvInstruction *, 4> loadedElems;
289-
uint32_t fieldOffsetInBytes = 0;
290-
uint32_t structAlignment = 0, structSize = 0, stride = 0;
291-
std::tie(structAlignment, structSize) =
292-
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
293-
.getAlignmentAndSize(targetType,
294-
theEmitter.getSpirvOptions().sBufferLayoutRule,
295-
llvm::None, &stride);
296-
for (const auto *field : decl->fields()) {
297-
AlignmentSizeCalculator alignmentCalc(astContext,
298-
theEmitter.getSpirvOptions());
299-
uint32_t fieldSize = 0, fieldAlignment = 0;
300-
std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
301-
field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
302-
/*isRowMajor*/ llvm::None, &stride);
303-
fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
304-
auto *byteOffset = address.getByteAddress();
305-
if (fieldOffsetInBytes != 0) {
306-
byteOffset = spvBuilder.createBinaryOp(
307-
spv::Op::OpIAdd, astContext.UnsignedIntTy, byteOffset,
308-
spvBuilder.getConstantInt(astContext.UnsignedIntTy,
309-
llvm::APInt(32, fieldOffsetInBytes)),
310-
loc, range);
311-
}
312-
313-
loadedElems.push_back(processTemplatedLoadFromBuffer(
314-
buffer, byteOffset, field->getType(), range));
296+
forEachSpirvField(
297+
structType, spvType,
298+
[this, &buffer, &address, range,
299+
&loadedElems](size_t spirvFieldIndex, const QualType &fieldType,
300+
const auto &field) {
301+
auto *baseOffset = address.getByteAddress();
302+
if (field.offset.hasValue() && field.offset.getValue() != 0) {
303+
const auto loc = buffer->getSourceLocation();
304+
SpirvConstant *offset = spvBuilder.getConstantInt(
305+
astContext.UnsignedIntTy,
306+
llvm::APInt(32, field.offset.getValue()));
307+
baseOffset = spvBuilder.createBinaryOp(
308+
spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset, offset,
309+
loc, range);
310+
}
315311

316-
fieldOffsetInBytes += fieldSize;
317-
}
312+
loadedElems.push_back(processTemplatedLoadFromBuffer(
313+
buffer, baseOffset, fieldType, range));
314+
return true;
315+
});
318316

319317
// After we're done with loading the entire struct, we need to update the
320318
// byteAddress (in case we are loading an array of structs).
321319
//
322320
// struct size = 34 bytes (34 / 8) = 4 full words (34 % 8) = 2 > 0,
323321
// therefore need to move to the next aligned address So the starting byte
324322
// offset after loading the entire struct is: 8 * (4 + 1) = 40
323+
uint32_t structAlignment = 0, structSize = 0, stride = 0;
324+
std::tie(structAlignment, structSize) =
325+
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
326+
.getAlignmentAndSize(targetType,
327+
theEmitter.getSpirvOptions().sBufferLayoutRule,
328+
llvm::None, &stride);
329+
325330
assert(structAlignment != 0);
326331
SpirvInstruction *structWidth = spvBuilder.getConstantInt(
327332
astContext.UnsignedIntTy,
@@ -577,7 +582,7 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
577582
return;
578583
default:
579584
theEmitter.emitError(
580-
"templated load of ByteAddressBuffer is only implemented for "
585+
"templated store of ByteAddressBuffer is only implemented for "
581586
"16, 32, and 64-bit types",
582587
loc);
583588
return;
@@ -604,40 +609,36 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
604609
// aligned like their field with the largest alignment.
605610
// As a result, there might exist some padding after some struct members.
606611
if (const auto *structType = valueType->getAs<RecordType>()) {
607-
const auto *decl = structType->getDecl();
608-
uint32_t fieldOffsetInBytes = 0;
609-
uint32_t structAlignment = 0, structSize = 0, stride = 0;
610-
std::tie(structAlignment, structSize) =
611-
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
612-
.getAlignmentAndSize(valueType,
613-
theEmitter.getSpirvOptions().sBufferLayoutRule,
614-
llvm::None, &stride);
615-
uint32_t fieldIndex = 0;
616-
for (const auto *field : decl->fields()) {
617-
AlignmentSizeCalculator alignmentCalc(astContext,
618-
theEmitter.getSpirvOptions());
619-
uint32_t fieldSize = 0, fieldAlignment = 0;
620-
std::tie(fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize(
621-
field->getType(), theEmitter.getSpirvOptions().sBufferLayoutRule,
622-
/*isRowMajor*/ llvm::None, &stride);
623-
fieldOffsetInBytes = roundToPow2(fieldOffsetInBytes, fieldAlignment);
624-
auto *byteOffset = address.getByteAddress();
625-
if (fieldOffsetInBytes != 0) {
626-
byteOffset = spvBuilder.createBinaryOp(
627-
spv::Op::OpIAdd, astContext.UnsignedIntTy, byteOffset,
628-
spvBuilder.getConstantInt(astContext.UnsignedIntTy,
629-
llvm::APInt(32, fieldOffsetInBytes)),
630-
loc, range);
631-
}
612+
LowerTypeVisitor lowerTypeVisitor(astContext, theEmitter.getSpirvContext(),
613+
theEmitter.getSpirvOptions(), spvBuilder);
614+
auto *decl = valueType->getAsTagDecl();
615+
assert(decl && "Expected all structs to be tag decls.");
616+
const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType(
617+
valueType, theEmitter.getSpirvOptions().sBufferLayoutRule, llvm::None,
618+
decl->getLocation()));
619+
assert(spvType);
620+
forEachSpirvField(
621+
structType, spvType,
622+
[this, &address, loc, range, buffer, value](size_t spirvFieldIndex,
623+
const QualType &fieldType,
624+
const auto &field) {
625+
auto *baseOffset = address.getByteAddress();
626+
if (field.offset.hasValue() && field.offset.getValue() != 0) {
627+
SpirvConstant *offset = spvBuilder.getConstantInt(
628+
astContext.UnsignedIntTy,
629+
llvm::APInt(32, field.offset.getValue()));
630+
baseOffset = spvBuilder.createBinaryOp(
631+
spv::Op::OpIAdd, astContext.UnsignedIntTy, baseOffset, offset,
632+
loc, range);
633+
}
632634

633-
processTemplatedStoreToBuffer(
634-
spvBuilder.createCompositeExtract(field->getType(), value,
635-
{fieldIndex}, loc, range),
636-
buffer, byteOffset, field->getType(), range);
637-
638-
fieldOffsetInBytes += fieldSize;
639-
++fieldIndex;
640-
}
635+
processTemplatedStoreToBuffer(
636+
spvBuilder.createCompositeExtract(
637+
fieldType, value, {static_cast<uint32_t>(spirvFieldIndex)},
638+
loc, range),
639+
buffer, baseOffset, fieldType, range);
640+
return true;
641+
});
641642

642643
// After we're done with storing the entire struct, we need to update the
643644
// byteAddress (in case we are storing an array of structs).
@@ -647,6 +648,13 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
647648
// (34 % 8) = 2 > 0, therefore need to move to the next aligned address
648649
// So the starting byte offset after loading the entire struct is:
649650
// 8 * (4 + 1) = 40
651+
uint32_t structAlignment = 0, structSize = 0, stride = 0;
652+
std::tie(structAlignment, structSize) =
653+
AlignmentSizeCalculator(astContext, theEmitter.getSpirvOptions())
654+
.getAlignmentAndSize(valueType,
655+
theEmitter.getSpirvOptions().sBufferLayoutRule,
656+
llvm::None, &stride);
657+
650658
assert(structAlignment != 0);
651659
auto *structWidth = spvBuilder.getConstantInt(
652660
astContext.UnsignedIntTy,

tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.load.hlsl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
ByteAddressBuffer myBuffer;
44

5+
struct S {
6+
uint32_t x : 8;
7+
uint32_t y : 8;
8+
};
9+
510
[numthreads(1, 1, 1)]
611
void main() {
712
uint addr = 0;
@@ -50,4 +55,11 @@ void main() {
5055
// CHECK-NEXT: [[load4_word3:%[0-9]+]] = OpLoad %uint [[load_ptr6]]
5156
// CHECK-NEXT: {{%[0-9]+}} = OpCompositeConstruct %v4uint [[load4_word0]] [[load4_word1]] [[load4_word2]] [[load4_word3]]
5257
uint4 word4 = myBuffer.Load4(addr);
58+
59+
// CHECK: [[idx:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
60+
// CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %myBuffer %uint_0 [[idx]]
61+
// CHECK: [[bitfield:%[0-9]+]] = OpLoad %uint [[ac]]
62+
// CHECK: [[s:%[0-9]+]] = OpCompositeConstruct %S [[bitfield]]
63+
// CHECK: OpStore %s [[s]]
64+
S s = myBuffer.Load<S>(0);
5365
}

tools/clang/test/CodeGenSPIRV/method.byte-address-buffer.store.hlsl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
RWByteAddressBuffer outBuffer;
44

5+
struct S
6+
{
7+
uint32_t x:8;
8+
uint32_t y:8;
9+
};
10+
511
[numthreads(1, 1, 1)]
612
void main() {
713
uint addr = 0;
@@ -67,4 +73,13 @@ void main() {
6773
// CHECK-NEXT: [[outBufPtr3:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[baseAddr_plus3]]
6874
// CHECK-NEXT: OpStore [[outBufPtr3]] [[word3]]
6975
outBuffer.Store4(addr, words4);
76+
77+
// CHECK: [[s:%[0-9]+]] = OpLoad %S %s
78+
// CHECK: [[bitfield:%[0-9]+]] = OpCompositeExtract %uint [[s]] 0
79+
// CHECK: [[idx:%[0-9]+]] = OpShiftRightLogical %uint %uint_0 %uint_2
80+
// CHECK: [[ac:%[0-9]+]] = OpAccessChain %_ptr_Uniform_uint %outBuffer %uint_0 [[idx]]
81+
// CHECK: OpStore [[ac]] [[bitfield]]
82+
S s = (S)0;
83+
s.x = 5;
84+
outBuffer.Store(0, s);
7085
}

0 commit comments

Comments
 (0)