Skip to content

Commit 1e3da15

Browse files
authored
[SPIRV] Handle partial template class specialization (#7673)
The SpirvEmitter does not handle ClassTemplatePartialSpecializationDecl. Since it extends RecordDecl, DXC attempts to generate code for the partial specialization, but fails because that is not possible. We need to avoid trying to generate code for the partial specialization and wait for a full specialization. Fixes #7007
1 parent 143da22 commit 1e3da15

File tree

2 files changed

+66
-3
lines changed

2 files changed

+66
-3
lines changed

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,6 +1021,17 @@ void SpirvEmitter::doDecl(const Decl *decl) {
10211021
// functions inside namespaces.
10221022
if (!isa<FunctionDecl>(subDecl))
10231023
doDecl(subDecl);
1024+
} else if (const auto *classTemplateDecl =
1025+
dyn_cast<ClassTemplateDecl>(decl)) {
1026+
doClassTemplateDecl(classTemplateDecl);
1027+
} else if (const auto *classTemplateDecl =
1028+
dyn_cast<ClassTemplatePartialSpecializationDecl>(decl)) {
1029+
// Do nothing. We cannot generate any code with a partial specialization,
1030+
// and when there is a specialization of this decl it will be a
1031+
// specialization of the original ClassTemplateDecl that this specializes.
1032+
// The code for the full specialization will be handled when processing the
1033+
// ClassTemplateDecl. Note that this is also a RecordDecl, so we must check
1034+
// for it before RecordDecl.
10241035
} else if (const auto *funcDecl = dyn_cast<FunctionDecl>(decl)) {
10251036
doFunctionDecl(funcDecl);
10261037
} else if (const auto *bufferDecl = dyn_cast<HLSLBufferDecl>(decl)) {
@@ -1029,9 +1040,6 @@ void SpirvEmitter::doDecl(const Decl *decl) {
10291040
doRecordDecl(recordDecl);
10301041
} else if (const auto *enumDecl = dyn_cast<EnumDecl>(decl)) {
10311042
doEnumDecl(enumDecl);
1032-
} else if (const auto *classTemplateDecl =
1033-
dyn_cast<ClassTemplateDecl>(decl)) {
1034-
doClassTemplateDecl(classTemplateDecl);
10351043
} else if (isa<TypedefNameDecl>(decl)) {
10361044
declIdMapper.recordsSpirvTypeAlias(decl);
10371045
} else if (isa<FunctionTemplateDecl>(decl)) {
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: %dxc -HV 2021 -T cs_6_7 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
template<typename MatT>
4+
struct matrix_traits;
5+
6+
template<typename T, int32_t N, int32_t M>
7+
struct matrix_traits<matrix<T,N,M> >
8+
{
9+
static const uint32_t RowCount = N;
10+
static const uint32_t ColumnCount = M;
11+
};
12+
13+
template<typename MatT>
14+
uint32_t elementCount()
15+
{
16+
return matrix_traits<MatT>::RowCount * matrix_traits<MatT>::ColumnCount;
17+
}
18+
19+
RWBuffer<int> o;
20+
21+
// Initialize the static members at the start of wrapper
22+
// CHECK: %main = OpFunction %void None
23+
// CHECK: OpStore %RowCount %uint_4
24+
// CHECK: OpStore %ColumnCount %uint_4
25+
// CHECK: OpStore %RowCount_0 %uint_3
26+
// CHECK: OpStore %ColumnCount_0 %uint_2
27+
// CHECK: OpFunctionEnd
28+
29+
30+
31+
// CHECK: %src_main = OpFunction %void None
32+
[numthreads(64,1,1)]
33+
void main()
34+
{
35+
// CHECK: OpFunctionCall %uint %elementCount
36+
o[0] = elementCount<float32_t4x4>();
37+
// CHECK: OpFunctionCall %uint %elementCount_0
38+
o[1] = elementCount<float32_t3x2>();
39+
}
40+
41+
// CHECK: %elementCount = OpFunction %uint None
42+
// CHECK-NEXT: OpLabel
43+
// CHECK-NEXT: [[rc:%[0-9]+]] = OpLoad %uint %RowCount
44+
// CHECK-NEXT: [[cc:%[0-9]+]] = OpLoad %uint %ColumnCount
45+
// CHECK-NEXT: [[mul:%[0-9]+]] = OpIMul %uint [[rc]] [[cc]]
46+
// CHECK-NEXT: OpReturnValue [[mul]]
47+
// CHECK-NEXT: OpFunctionEnd
48+
49+
// CHECK: %elementCount_0 = OpFunction %uint None
50+
// CHECK-NEXT: %bb_entry_1 = OpLabel
51+
// CHECK-NEXT: [[rc:%[0-9]+]] = OpLoad %uint %RowCount_0
52+
// CHECK-NEXT: [[cc:%[0-9]+]] = OpLoad %uint %ColumnCount_0
53+
// CHECK-NEXT: [[mul:%[0-9]+]] = OpIMul %uint [[rc]] [[cc]]
54+
// CHECK-NEXT: OpReturnValue [[mul]]
55+
// CHECK-NEXT: OpFunctionEnd

0 commit comments

Comments
 (0)