Skip to content

Commit 5e933ab

Browse files
authored
Handle local arrays of RWStructrued buffers (microsoft#5523)
This commit adds support for local arrays of RWStructuredBuffers.
1 parent 969ff6e commit 5e933ab

File tree

4 files changed

+66
-6
lines changed

4 files changed

+66
-6
lines changed

tools/clang/lib/SPIRV/AstTypeProbe.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,6 +1001,9 @@ bool isAKindOfStructuredOrByteBuffer(QualType type) {
10011001
}
10021002

10031003
bool isOrContainsAKindOfStructuredOrByteBuffer(QualType type) {
1004+
while (type->isArrayType())
1005+
type = type->getAsArrayTypeUnsafe()->getElementType();
1006+
10041007
if (const RecordType *recordType = type->getAs<RecordType>()) {
10051008
StringRef name = recordType->getDecl()->getName();
10061009
if (name == "StructuredBuffer" || name == "RWStructuredBuffer" ||

tools/clang/lib/SPIRV/LowerTypeVisitor.cpp

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -530,10 +530,24 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
530530
// Array type
531531
if (const auto *arrayType = astContext.getAsArrayType(type)) {
532532
const auto elemType = arrayType->getElementType();
533-
const auto *loweredElemType =
534-
lowerType(arrayType->getElementType(), rule, isRowMajor, srcLoc);
535-
llvm::Optional<uint32_t> arrayStride = llvm::None;
536533

534+
// If layout rule is void, it means these resource types are used for
535+
// declaring local resources. This should be lowered to a pointer to the
536+
// array.
537+
//
538+
// The pointer points to the Uniform storage class, and the element type
539+
// should have the corresponding layout.
540+
bool isLocalStructuredOrByteBuffer =
541+
isAKindOfStructuredOrByteBuffer(elemType) &&
542+
rule == SpirvLayoutRule::Void;
543+
544+
SpirvLayoutRule elementLayoutRule =
545+
(isLocalStructuredOrByteBuffer ? getCodeGenOptions().sBufferLayoutRule
546+
: rule);
547+
const SpirvType *loweredElemType =
548+
lowerType(elemType, elementLayoutRule, isRowMajor, srcLoc);
549+
550+
llvm::Optional<uint32_t> arrayStride = llvm::None;
537551
if (rule != SpirvLayoutRule::Void &&
538552
// We won't have stride information for structured/byte buffers since
539553
// they contain runtime arrays.
@@ -544,13 +558,23 @@ const SpirvType *LowerTypeVisitor::lowerType(QualType type,
544558
arrayStride = stride;
545559
}
546560

561+
const SpirvType *spirvArrayType = nullptr;
547562
if (const auto *caType = astContext.getAsConstantArrayType(type)) {
548563
const auto size = static_cast<uint32_t>(caType->getSize().getZExtValue());
549-
return spvContext.getArrayType(loweredElemType, size, arrayStride);
564+
spirvArrayType =
565+
spvContext.getArrayType(loweredElemType, size, arrayStride);
566+
} else {
567+
assert(type->isIncompleteArrayType());
568+
spirvArrayType =
569+
spvContext.getRuntimeArrayType(loweredElemType, arrayStride);
570+
}
571+
572+
if (isLocalStructuredOrByteBuffer) {
573+
return spvContext.getPointerType(spirvArrayType,
574+
spv::StorageClass::Uniform);
550575
}
551576

552-
assert(type->isIncompleteArrayType());
553-
return spvContext.getRuntimeArrayType(loweredElemType, arrayStride);
577+
return spirvArrayType;
554578
}
555579

556580
// Reference types
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// RUN: %dxc -T ps_6_6 -E main -O0 -fvk-allow-rwstructuredbuffer-arrays
2+
3+
struct PSInput
4+
{
5+
uint idx : COLOR;
6+
};
7+
8+
// CHECK: OpDecorate %g_rwbuffer DescriptorSet 2
9+
// CHECK: OpDecorate %g_rwbuffer Binding 0
10+
// CHECK: OpDecorate %counter_var_g_rwbuffer DescriptorSet 2
11+
// CHECK: OpDecorate %counter_var_g_rwbuffer Binding 1
12+
13+
// CHECK: %g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_RWStructuredBuffer_uint_uint_5 Uniform
14+
// CHECK: %counter_var_g_rwbuffer = OpVariable %_ptr_Uniform__arr_type_ACSBuffer_counter_uint_5 Uniform
15+
RWStructuredBuffer<uint> g_rwbuffer[5] : register(u0, space2);
16+
17+
float4 main(PSInput input) : SV_TARGET
18+
{
19+
RWStructuredBuffer<uint> l_rwbuffer[5] = g_rwbuffer;
20+
21+
// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_ACSBuffer_counter %counter_var_g_rwbuffer %int_0
22+
// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_int [[ac1]] %uint_0
23+
// CHECK: OpAtomicIAdd %int [[ac2]] %uint_1 %uint_0 %int_1
24+
l_rwbuffer[0].IncrementCounter();
25+
26+
// CHECK: [[ac1:%\d+]] = OpAccessChain %_ptr_Uniform_type_RWStructuredBuffer_uint %g_rwbuffer {{%\d+}}
27+
// CHECK: [[ac2:%\d+]] = OpAccessChain %_ptr_Uniform_uint [[ac1]] %int_0 %uint_0
28+
// CHECK: OpLoad %uint [[ac2]]
29+
return l_rwbuffer[input.idx][0];
30+
}

tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,9 @@ TEST_F(FileTest, RWStructuredBufferArrayCounterFlattened) {
162162
TEST_F(FileTest, RWStructuredBufferArrayCounterIndirect) {
163163
runFileTest("type.rwstructured-buffer.array.counter.indirect.hlsl");
164164
}
165+
TEST_F(FileTest, RWStructuredBufferArrayCounterIndirect2) {
166+
runFileTest("type.rwstructured-buffer.array.counter.indirect2.hlsl");
167+
}
165168
TEST_F(FileTest, RWStructuredBufferArrayBindAttributes) {
166169
runFileTest("type.rwstructured-buffer.array.binding.attributes.hlsl");
167170
}

0 commit comments

Comments
 (0)