Skip to content

Commit a61f634

Browse files
committed
[mlir][spirv] Split conditional basic blocks during deserialization
With the current design some of the values are sank into a selection region, despite them being also used outside that region. This is because the current deserializer logic sinks the entire basic block containing a conditional branch forming a header of a selection construct, without accounting for some values being used outside. This manifests as (for example): ``` <unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct <unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function> ``` The proposed solution to this problem is to split the conditional basic block into two, one block containing just the conditional branch, and other the rest of instructions. By doing this, the logic that structures selection regions, only sinks the comparison, keeping the rest of instructions outside the selection region. A SPIR-V test is required, as the problem can happen only during deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR test exhibiting the problematic behaviour would be an incorrect MLIR in the first place. This solution is proposed as an alternative to an unfinished PR #123371, that is unlikely to be merged in the foreseeable future, as the author "stepped away from this for a time being". There is also a Discourse thread: https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494 that tried to solicit some feedback on the topic.
1 parent 6646b65 commit a61f634

File tree

6 files changed

+107
-0
lines changed

6 files changed

+107
-0
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,13 +2158,55 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
21582158
return success();
21592159
}
21602160

2161+
LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2162+
auto splitBlock = [&](Block *block) {
2163+
// Do not split loop headers
2164+
if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
2165+
if (it->second.continueBlock) {
2166+
return;
2167+
}
2168+
}
2169+
2170+
if (!block->mightHaveTerminator())
2171+
return;
2172+
2173+
Operation *terminator = block->getTerminator();
2174+
assert(terminator != nullptr);
2175+
2176+
if (isa<spirv::BranchConditionalOp>(terminator) &&
2177+
std::distance(block->begin(), block->end()) > 1) {
2178+
Block *newBlock = block->splitBlock(terminator);
2179+
OpBuilder builder(block, block->end());
2180+
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2181+
2182+
if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
2183+
BlockMergeInfo value = std::move(it->second);
2184+
blockMergeInfo.erase(it);
2185+
blockMergeInfo.try_emplace(newBlock, std::move(value));
2186+
}
2187+
}
2188+
};
2189+
curFunction->walk(splitBlock);
2190+
2191+
return success();
2192+
}
2193+
21612194
LogicalResult spirv::Deserializer::structurizeControlFlow() {
21622195
LLVM_DEBUG({
21632196
logger.startLine()
21642197
<< "//----- [cf] start structurizing control flow -----//\n";
21652198
logger.indent();
21662199
});
21672200

2201+
LLVM_DEBUG({
2202+
logger.startLine() << "[cf] split conditional blocks\n";
2203+
logger.startLine() << "\n";
2204+
});
2205+
2206+
if (failed(splitConditionalBlocks())) {
2207+
return failure();
2208+
}
2209+
21682210
while (!blockMergeInfo.empty()) {
21692211
Block *headerBlock = blockMergeInfo.begin()->first;
21702212
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ class Deserializer {
246246
return opBuilder.getStringAttr(attrName);
247247
}
248248

249+
// Move a conditional branch into a separate basic block to avoid sinking
250+
// defs that are required outside a selection region.
251+
LogicalResult splitConditionalBlocks();
252+
249253
//===--------------------------------------------------------------------===//
250254
// Type
251255
//===--------------------------------------------------------------------===//

mlir/test/Target/SPIRV/loop.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
267267
}
268268
// CHECK-NEXT: %[[LOAD:.+]] = spirv.Load "Function" %[[VAR]] : i1
269269
%load = spirv.Load "Function" %var : i1
270+
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
271+
// CHECK-NEXT: ^[[BB]]
270272
// CHECK-NEXT: spirv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]]
271273
spirv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge
272274
// CHECK-NEXT: ^[[CONTINUE]](%[[ARG2:.+]]: i64):

mlir/test/Target/SPIRV/selection.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
105105
%var = spirv.Variable : !spirv.ptr<i1, Function>
106106
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
107107
// CHECK-NEXT: ^[[BB]]:
108+
// CHECK: spirv.Branch ^[[BB:.+]]
109+
// CHECK-NEXT: ^[[BB]]:
108110

109111
// CHECK-NEXT: spirv.mlir.selection {
110112
spirv.mlir.selection {
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
2+
; COM: The purpose of this test is to check that a variable (in this case %color) that
3+
; COM: is defined before a selection region and used both in the selection region and
4+
; COM: after it, is not sunk into that selection region by the deserializer. If the
5+
; COM: variable is sunk, then it cannot be accessed outside the region and causes
6+
; COM: control-flow structurization to fail.
7+
; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []>
8+
OpCapability Shader
9+
%2 = OpExtInstImport "GLSL.std.450"
10+
OpMemoryModel Logical GLSL450
11+
OpEntryPoint Fragment %main "main" %colorOut
12+
OpExecutionMode %main OriginUpperLeft
13+
OpDecorate %colorOut Location 0
14+
%void = OpTypeVoid
15+
%4 = OpTypeFunction %void
16+
%float = OpTypeFloat 32
17+
%v4float = OpTypeVector %float 4
18+
%fun_v4float = OpTypePointer Function %v4float
19+
%float_1 = OpConstant %float 1
20+
%float_0 = OpConstant %float 0
21+
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
22+
%out_v4float = OpTypePointer Output %v4float
23+
%colorOut = OpVariable %out_v4float Output
24+
%uint = OpTypeInt 32 0
25+
%uint_0 = OpConstant %uint 0
26+
%out_float = OpTypePointer Output %float
27+
%bool = OpTypeBool
28+
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
29+
; CHECK: spirv.func @main() "None" {
30+
%main = OpFunction %void None %4
31+
%6 = OpLabel
32+
; CHECK: spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
33+
%color = OpVariable %fun_v4float Function
34+
OpStore %color %13
35+
%19 = OpAccessChain %out_float %colorOut %uint_0
36+
%20 = OpLoad %float %19
37+
%22 = OpFOrdEqual %bool %20 %float_1
38+
; CHECK: spirv.mlir.selection {
39+
OpSelectionMerge %24 None
40+
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
41+
OpBranchConditional %22 %23 %24
42+
; CHECK-NEXT: ^[[bb:.+]]
43+
%23 = OpLabel
44+
OpStore %color %25
45+
; CHECK: spirv.Branch ^[[bb:.+]]
46+
OpBranch %24
47+
; CHECK-NEXT: ^[[bb:.+]]:
48+
%24 = OpLabel
49+
; CHECK-NEXT: spirv.mlir.merge
50+
; CHECK-NEXT: }
51+
%26 = OpLoad %v4float %color
52+
OpStore %colorOut %26
53+
; CHECK: spirv.Return
54+
OpReturn
55+
; CHECK-NEXT: }
56+
OpFunctionEnd

mlir/test/lit.cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
".test",
4444
".pdll",
4545
".c",
46+
".spv",
4647
]
4748

4849
# test_source_root: The root path where tests are located.

0 commit comments

Comments
 (0)