8
8
9
9
#include " RawBufferMethods.h"
10
10
#include " AlignmentSizeCalculator.h"
11
+ #include " LowerTypeVisitor.h"
11
12
#include " clang/AST/ASTContext.h"
12
13
#include " clang/AST/CharUnits.h"
13
14
#include " clang/AST/RecordLayout.h"
@@ -284,44 +285,48 @@ SpirvInstruction *RawBufferHandler::processTemplatedLoadFromBuffer(
284
285
// aligned like their field with the largest alignment.
285
286
// As a result, there might exist some padding after some struct members.
286
287
if (const auto *structType = targetType->getAs <RecordType>()) {
287
- const auto *decl = structType->getDecl ();
288
+ LowerTypeVisitor lowerTypeVisitor (astContext, theEmitter.getSpirvContext (),
289
+ theEmitter.getSpirvOptions (), spvBuilder);
290
+ auto *decl = targetType->getAsTagDecl ();
291
+ assert (decl && " Expected all structs to be tag decls." );
292
+ const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType (
293
+ targetType, theEmitter.getSpirvOptions ().sBufferLayoutRule , llvm::None,
294
+ decl->getLocation ()));
288
295
llvm::SmallVector<SpirvInstruction *, 4 > loadedElems;
289
- uint32_t fieldOffsetInBytes = 0 ;
290
- uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
291
- std::tie (structAlignment, structSize) =
292
- AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
293
- .getAlignmentAndSize (targetType,
294
- theEmitter.getSpirvOptions ().sBufferLayoutRule ,
295
- llvm::None, &stride);
296
- for (const auto *field : decl->fields ()) {
297
- AlignmentSizeCalculator alignmentCalc (astContext,
298
- theEmitter.getSpirvOptions ());
299
- uint32_t fieldSize = 0 , fieldAlignment = 0 ;
300
- std::tie (fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize (
301
- field->getType (), theEmitter.getSpirvOptions ().sBufferLayoutRule ,
302
- /* isRowMajor*/ llvm::None, &stride);
303
- fieldOffsetInBytes = roundToPow2 (fieldOffsetInBytes, fieldAlignment);
304
- auto *byteOffset = address.getByteAddress ();
305
- if (fieldOffsetInBytes != 0 ) {
306
- byteOffset = spvBuilder.createBinaryOp (
307
- spv::Op::OpIAdd, astContext.UnsignedIntTy , byteOffset,
308
- spvBuilder.getConstantInt (astContext.UnsignedIntTy ,
309
- llvm::APInt (32 , fieldOffsetInBytes)),
310
- loc, range);
311
- }
312
-
313
- loadedElems.push_back (processTemplatedLoadFromBuffer (
314
- buffer, byteOffset, field->getType (), range));
296
+ forEachSpirvField (
297
+ structType, spvType,
298
+ [this , &buffer, &address, range,
299
+ &loadedElems](size_t spirvFieldIndex, const QualType &fieldType,
300
+ const auto &field) {
301
+ auto *baseOffset = address.getByteAddress ();
302
+ if (field.offset .hasValue () && field.offset .getValue () != 0 ) {
303
+ const auto loc = buffer->getSourceLocation ();
304
+ SpirvConstant *offset = spvBuilder.getConstantInt (
305
+ astContext.UnsignedIntTy ,
306
+ llvm::APInt (32 , field.offset .getValue ()));
307
+ baseOffset = spvBuilder.createBinaryOp (
308
+ spv::Op::OpIAdd, astContext.UnsignedIntTy , baseOffset, offset,
309
+ loc, range);
310
+ }
315
311
316
- fieldOffsetInBytes += fieldSize;
317
- }
312
+ loadedElems.push_back (processTemplatedLoadFromBuffer (
313
+ buffer, baseOffset, fieldType, range));
314
+ return true ;
315
+ });
318
316
319
317
// After we're done with loading the entire struct, we need to update the
320
318
// byteAddress (in case we are loading an array of structs).
321
319
//
322
320
// struct size = 34 bytes (34 / 8) = 4 full words (34 % 8) = 2 > 0,
323
321
// therefore need to move to the next aligned address So the starting byte
324
322
// offset after loading the entire struct is: 8 * (4 + 1) = 40
323
+ uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
324
+ std::tie (structAlignment, structSize) =
325
+ AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
326
+ .getAlignmentAndSize (targetType,
327
+ theEmitter.getSpirvOptions ().sBufferLayoutRule ,
328
+ llvm::None, &stride);
329
+
325
330
assert (structAlignment != 0 );
326
331
SpirvInstruction *structWidth = spvBuilder.getConstantInt (
327
332
astContext.UnsignedIntTy ,
@@ -577,7 +582,7 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
577
582
return ;
578
583
default :
579
584
theEmitter.emitError (
580
- " templated load of ByteAddressBuffer is only implemented for "
585
+ " templated store of ByteAddressBuffer is only implemented for "
581
586
" 16, 32, and 64-bit types" ,
582
587
loc);
583
588
return ;
@@ -604,40 +609,36 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
604
609
// aligned like their field with the largest alignment.
605
610
// As a result, there might exist some padding after some struct members.
606
611
if (const auto *structType = valueType->getAs <RecordType>()) {
607
- const auto *decl = structType->getDecl ();
608
- uint32_t fieldOffsetInBytes = 0 ;
609
- uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
610
- std::tie (structAlignment, structSize) =
611
- AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
612
- .getAlignmentAndSize (valueType,
613
- theEmitter.getSpirvOptions ().sBufferLayoutRule ,
614
- llvm::None, &stride);
615
- uint32_t fieldIndex = 0 ;
616
- for (const auto *field : decl->fields ()) {
617
- AlignmentSizeCalculator alignmentCalc (astContext,
618
- theEmitter.getSpirvOptions ());
619
- uint32_t fieldSize = 0 , fieldAlignment = 0 ;
620
- std::tie (fieldAlignment, fieldSize) = alignmentCalc.getAlignmentAndSize (
621
- field->getType (), theEmitter.getSpirvOptions ().sBufferLayoutRule ,
622
- /* isRowMajor*/ llvm::None, &stride);
623
- fieldOffsetInBytes = roundToPow2 (fieldOffsetInBytes, fieldAlignment);
624
- auto *byteOffset = address.getByteAddress ();
625
- if (fieldOffsetInBytes != 0 ) {
626
- byteOffset = spvBuilder.createBinaryOp (
627
- spv::Op::OpIAdd, astContext.UnsignedIntTy , byteOffset,
628
- spvBuilder.getConstantInt (astContext.UnsignedIntTy ,
629
- llvm::APInt (32 , fieldOffsetInBytes)),
630
- loc, range);
631
- }
612
+ LowerTypeVisitor lowerTypeVisitor (astContext, theEmitter.getSpirvContext (),
613
+ theEmitter.getSpirvOptions (), spvBuilder);
614
+ auto *decl = valueType->getAsTagDecl ();
615
+ assert (decl && " Expected all structs to be tag decls." );
616
+ const StructType *spvType = dyn_cast<StructType>(lowerTypeVisitor.lowerType (
617
+ valueType, theEmitter.getSpirvOptions ().sBufferLayoutRule , llvm::None,
618
+ decl->getLocation ()));
619
+ assert (spvType);
620
+ forEachSpirvField (
621
+ structType, spvType,
622
+ [this , &address, loc, range, buffer, value](size_t spirvFieldIndex,
623
+ const QualType &fieldType,
624
+ const auto &field) {
625
+ auto *baseOffset = address.getByteAddress ();
626
+ if (field.offset .hasValue () && field.offset .getValue () != 0 ) {
627
+ SpirvConstant *offset = spvBuilder.getConstantInt (
628
+ astContext.UnsignedIntTy ,
629
+ llvm::APInt (32 , field.offset .getValue ()));
630
+ baseOffset = spvBuilder.createBinaryOp (
631
+ spv::Op::OpIAdd, astContext.UnsignedIntTy , baseOffset, offset,
632
+ loc, range);
633
+ }
632
634
633
- processTemplatedStoreToBuffer (
634
- spvBuilder.createCompositeExtract (field->getType (), value,
635
- {fieldIndex}, loc, range),
636
- buffer, byteOffset, field->getType (), range);
637
-
638
- fieldOffsetInBytes += fieldSize;
639
- ++fieldIndex;
640
- }
635
+ processTemplatedStoreToBuffer (
636
+ spvBuilder.createCompositeExtract (
637
+ fieldType, value, {static_cast <uint32_t >(spirvFieldIndex)},
638
+ loc, range),
639
+ buffer, baseOffset, fieldType, range);
640
+ return true ;
641
+ });
641
642
642
643
// After we're done with storing the entire struct, we need to update the
643
644
// byteAddress (in case we are storing an array of structs).
@@ -647,6 +648,13 @@ void RawBufferHandler::processTemplatedStoreToBuffer(SpirvInstruction *value,
647
648
// (34 % 8) = 2 > 0, therefore need to move to the next aligned address
648
649
// So the starting byte offset after loading the entire struct is:
649
650
// 8 * (4 + 1) = 40
651
+ uint32_t structAlignment = 0 , structSize = 0 , stride = 0 ;
652
+ std::tie (structAlignment, structSize) =
653
+ AlignmentSizeCalculator (astContext, theEmitter.getSpirvOptions ())
654
+ .getAlignmentAndSize (valueType,
655
+ theEmitter.getSpirvOptions ().sBufferLayoutRule ,
656
+ llvm::None, &stride);
657
+
650
658
assert (structAlignment != 0 );
651
659
auto *structWidth = spvBuilder.getConstantInt (
652
660
astContext.UnsignedIntTy ,
0 commit comments