Skip to content

Commit 9c154fb

Browse files
authored
[SPIR-V] Fix flattened SPIR-V variable codegen. (microsoft#6756)
When the [branch] annotation is used, switches are converted into an if/else tree. Issue arise when declaring a variable into a switch case: - when flattened, the same variable could be traversed twice in case of a case fall-through. In addition, there was a small bug in the flattening logic: only breaks were stopping the handling of further statements, while early-returns were ignored. This was not an important bug as it only added more dead-code, but it was wrong. Fixes microsoft#6718 Signed-off-by: Nathan Gauër <[email protected]>
1 parent 1028410 commit 9c154fb

File tree

3 files changed

+97
-4
lines changed

3 files changed

+97
-4
lines changed

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1083,6 +1083,9 @@ void DeclResultIdMapper::createCounterVarForDecl(const DeclaratorDecl *decl) {
10831083
SpirvVariable *
10841084
DeclResultIdMapper::createFnVar(const VarDecl *var,
10851085
llvm::Optional<SpirvInstruction *> init) {
1086+
if (astDecls[var].instr != nullptr)
1087+
return cast<SpirvVariable>(astDecls[var].instr);
1088+
10861089
const auto type = getTypeOrFnRetType(var);
10871090
const auto loc = var->getLocation();
10881091
const auto name = var->getName();
@@ -1095,10 +1098,7 @@ DeclResultIdMapper::createFnVar(const VarDecl *var,
10951098
bool isAlias = false;
10961099
(void)getTypeAndCreateCounterForPotentialAliasVar(var, &isAlias);
10971100
varInstr->setContainsAliasComponent(isAlias);
1098-
1099-
assert(astDecls[var].instr == nullptr);
11001101
astDecls[var].instr = varInstr;
1101-
11021102
return varInstr;
11031103
}
11041104

tools/clang/lib/SPIRV/SpirvEmitter.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13768,7 +13768,9 @@ void SpirvEmitter::processSwitchStmtUsingIfStmts(const SwitchStmt *switchStmt) {
1376813768
// current case.
1376913769
std::vector<Stmt *> statements;
1377013770
unsigned i = curCaseIndex + 1;
13771-
for (; i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]); ++i) {
13771+
for (; i < flatSwitch.size() && !isa<BreakStmt>(flatSwitch[i]) &&
13772+
!isa<ReturnStmt>(flatSwitch[i]);
13773+
++i) {
1377213774
if (!isa<CaseStmt>(flatSwitch[i]) && !isa<DefaultStmt>(flatSwitch[i]))
1377313775
statements.push_back(const_cast<Stmt *>(flatSwitch[i]));
1377413776
}
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// RUN: %dxc -T cs_6_0 -E main -fcgl %s -spirv | FileCheck %s
2+
3+
void simple() {
4+
uint a = 0;
5+
6+
// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
7+
// CHECK: [[cond:%[0-9]+]] = OpIEqual %bool [[a]] %uint_0
8+
// CHECK: OpBranchConditional [[cond]] %if_true %if_false
9+
10+
// CHECK: %if_true = OpLabel
11+
// CHECK: OpStore %d1 %uint_0
12+
// CHECK: OpBranch %if_merge_0
13+
// CHECK: %if_false = OpLabel
14+
// CHECK: [[a:%[0-9]+]] = OpLoad %uint %a
15+
// CHECK: [[cond:%[0-9]+]] = OpIEqual %bool [[a]] %uint_1
16+
// CHECK: OpBranchConditional [[cond]] %if_true_0 %if_false_0
17+
18+
// CHECK: %if_true_0 = OpLabel
19+
// CHECK: OpStore %d1_0 %uint_1
20+
// CHECK: OpBranch %if_merge
21+
// CHECK: %if_false_0 = OpLabel
22+
// CHECK: OpBranch %if_merge
23+
24+
// CHECK: %if_merge = OpLabel
25+
// CHECK: OpBranch %if_merge_0
26+
27+
// CHECK: %if_merge_0 = OpLabel
28+
// CHECK: OpReturn
29+
[branch]
30+
switch (a) {
31+
default:
32+
return;
33+
case 0: {
34+
uint d1 = 0;
35+
return;
36+
}
37+
case 1: {
38+
uint d1 = 1;
39+
}
40+
}
41+
}
42+
43+
// CHECK: [[b:%[0-9]+]] = OpLoad %uint %b
44+
// CHECK: [[cond:%[0-9]+]] = OpIEqual %bool [[b]] %uint_0
45+
// CHECK: OpBranchConditional [[cond]] %if_true_1 %if_false_1
46+
47+
// CHECK: %if_true_1 = OpLabel
48+
// CHECK: OpStore %v1 %uint_0
49+
// CHECK: OpStore %v1_0 %uint_2
50+
// CHECK: OpStore %v2 %uint_1
51+
// CHECK: OpBranch %if_merge_2
52+
// CHECK: %if_false_1 = OpLabel
53+
// CHECK: [[b:%[0-9]+]] = OpLoad %uint %b
54+
// CHECK: [[cond:%[0-9]+]] = OpIEqual %bool [[b]] %uint_1
55+
// CHECK: OpBranchConditional [[cond]] %if_true_2 %if_false_2
56+
57+
// CHECK: %if_true_2 = OpLabel
58+
// CHECK: OpStore %v1_0 %uint_2
59+
// CHECK: OpStore %v2 %uint_1
60+
// CHECK: OpBranch %if_merge_1
61+
// CHECK: %if_false_2 = OpLabel
62+
// CHECK: OpBranch %if_merge_1
63+
64+
// CHECK: %if_merge_1 = OpLabel
65+
// CHECK: OpBranch %if_merge_2
66+
67+
// CHECK: %if_merge_2 = OpLabel
68+
// CHECK: OpReturn
69+
void fallthrough() {
70+
uint b = 0;
71+
72+
[branch]
73+
switch (b) {
74+
default:
75+
return;
76+
case 0: {
77+
uint v1 = 0;
78+
}
79+
case 1: {
80+
uint v1 = 2;
81+
uint v2 = 1;
82+
}
83+
}
84+
}
85+
86+
[numthreads(1, 1, 1)]
87+
void main()
88+
{
89+
simple();
90+
fallthrough();
91+
}

0 commit comments

Comments
 (0)