Skip to content

Commit 4c8d21a

Browse files
authored
[SPIRV] Handle member travrersal with template type (#7674)
The SPIR-V backend does not handle SubstTemplateTypeParmType types when traversing getting the SPIR-V fields. In these cases, we will always want the replacement type. So we modify forEachSpirvField to do that in all cases. Fixes #7178
1 parent 1e3da15 commit 4c8d21a

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

tools/clang/lib/SPIRV/AstTypeProbe.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,7 +1601,10 @@ void forEachSpirvField(
16011601
uint32_t lastConvertedIndex = 0;
16021602
size_t astFieldIndex = 0;
16031603
for (const auto &base : cxxDecl->bases()) {
1604-
const auto &type = base.getType();
1604+
auto type = base.getType();
1605+
if (auto *templatedType = dyn_cast<SubstTemplateTypeParmType>(type))
1606+
type = templatedType->getReplacementType();
1607+
16051608
const auto &spirvField = spirvType->getFields()[astFieldIndex];
16061609
if (!operation(spirvField.fieldIndex, type, spirvField)) {
16071610
return;
@@ -1620,7 +1623,10 @@ void forEachSpirvField(
16201623
continue;
16211624
}
16221625

1623-
const auto &type = field->getType();
1626+
auto type = field->getType();
1627+
if (auto *templatedType = dyn_cast<SubstTemplateTypeParmType>(type))
1628+
type = templatedType->getReplacementType();
1629+
16241630
if (!operation(currentFieldIndex, type, spirvField)) {
16251631
return;
16261632
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: %dxc -T cs_6_8 -E main %s -spirv | FileCheck %s
2+
3+
struct T
4+
{
5+
float3 direction;
6+
};
7+
8+
template<class TT>
9+
struct S
10+
{
11+
TT L;
12+
};
13+
14+
RWStructuredBuffer< S<T> > o;
15+
16+
// CHECK: [[f120:%.+]] = OpConstant %float 120
17+
// CHECK: [[v120:%.+]] = OpConstantComposite %v3float [[f120]] [[f120]] [[f120]]
18+
// CHECK: [[T:%.+]] = OpConstantComposite %T [[v120]]
19+
// CHECK: [[S:%.+]] = OpConstantComposite %S [[T]]
20+
21+
[numthreads(32, 32, 1)]
22+
void main(uint32_t threadID : SV_DispatchThreadID)
23+
{
24+
uint32_t infinity = 0x78;
25+
S<T> s = (S<T>)infinity;
26+
27+
// CHECK: OpStore {{%.*}} [[S]]
28+
o[0] = s;
29+
}

0 commit comments

Comments
 (0)