Skip to content

Commit 594919c

Browse files
authored
[mlir][spirv] Split conditional basic blocks during deserialization (#127639)
With the current design some of the values are sank into a selection region, despite them being also used outside that region. This is because the current deserializer logic sinks the entire basic block containing a conditional branch forming a header of a selection construct, without accounting for some values being used outside. This manifests as (for example): ``` <unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct <unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function> ``` The proposed solution to this problem is to split the conditional basic block into two, one block containing just the conditional branch, and other the rest of instructions. By doing this, the logic that structures selection regions, only sinks the comparison, keeping the rest of instructions outside the selection region. A SPIR-V test is required, as the problem can happen only during deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR test exhibiting the problematic behaviour would be an incorrect MLIR in the first place. This solution is proposed as an alternative to an unfinished PR #123371, that is unlikely to be merged in the foreseeable future, as the author "stepped away from this for a time being". There is also a Discourse thread: https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494 that tried to solicit some feedback on the topic.
1 parent 0caa8f4 commit 594919c

File tree

5 files changed

+126
-0
lines changed

5 files changed

+126
-0
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,13 +2158,72 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
21582158
return success();
21592159
}
21602160

2161+
LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2162+
// Create a copy, so we can modify keys in the original.
2163+
BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
2164+
for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
2165+
it != e; ++it) {
2166+
auto &[block, mergeInfo] = *it;
2167+
2168+
// Skip processing loop regions. For loop regions continueBlock is non-null.
2169+
if (mergeInfo.continueBlock)
2170+
continue;
2171+
2172+
if (!block->mightHaveTerminator())
2173+
continue;
2174+
2175+
Operation *terminator = block->getTerminator();
2176+
assert(terminator);
2177+
2178+
if (!isa<spirv::BranchConditionalOp>(terminator))
2179+
continue;
2180+
2181+
// Do not split blocks that only contain a conditional branch, i.e., block
2182+
// size is <= 1.
2183+
if (block->begin() != block->end() &&
2184+
std::next(block->begin()) != block->end()) {
2185+
Block *newBlock = block->splitBlock(terminator);
2186+
OpBuilder builder(block, block->end());
2187+
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2188+
2189+
// If the split block was a merge block of another region we need to
2190+
// update the map.
2191+
for (auto it = blockMergeInfo.begin(); it != blockMergeInfo.end(); ++it) {
2192+
auto &[ignore, mergeInfo] = *it;
2193+
if (mergeInfo.mergeBlock == block) {
2194+
mergeInfo.mergeBlock = newBlock;
2195+
}
2196+
}
2197+
2198+
// After splitting we need to update the map to use the new block as a
2199+
// header.
2200+
blockMergeInfo.erase(block);
2201+
blockMergeInfo.try_emplace(newBlock, mergeInfo);
2202+
}
2203+
}
2204+
2205+
return success();
2206+
}
2207+
21612208
LogicalResult spirv::Deserializer::structurizeControlFlow() {
21622209
LLVM_DEBUG({
21632210
logger.startLine()
21642211
<< "//----- [cf] start structurizing control flow -----//\n";
21652212
logger.indent();
21662213
});
21672214

2215+
LLVM_DEBUG({
2216+
logger.startLine() << "[cf] split conditional blocks\n";
2217+
logger.startLine() << "\n";
2218+
});
2219+
2220+
if (failed(splitConditionalBlocks())) {
2221+
return failure();
2222+
}
2223+
2224+
// TODO: This loop is non-deterministic. Iteration order may vary between runs
2225+
// for the same shader as the key to the map is a pointer. See:
2226+
// https://github.com/llvm/llvm-project/issues/128547
21682227
while (!blockMergeInfo.empty()) {
21692228
Block *headerBlock = blockMergeInfo.begin()->first;
21702229
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ class Deserializer {
246246
return opBuilder.getStringAttr(attrName);
247247
}
248248

249+
// Move a conditional branch into a separate basic block to avoid sinking
250+
// defs that are required outside a selection region.
251+
LogicalResult splitConditionalBlocks();
252+
249253
//===--------------------------------------------------------------------===//
250254
// Type
251255
//===--------------------------------------------------------------------===//

mlir/test/Target/SPIRV/selection.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
105105
%var = spirv.Variable : !spirv.ptr<i1, Function>
106106
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
107107
// CHECK-NEXT: ^[[BB]]:
108+
// CHECK: spirv.Branch ^[[BB:.+]]
109+
// CHECK-NEXT: ^[[BB]]:
108110

109111
// CHECK-NEXT: spirv.mlir.selection {
110112
spirv.mlir.selection {

mlir/test/Target/SPIRV/selection.spv

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
2+
3+
; COM: The purpose of this test is to check that a variable (in this case %color) that
4+
; COM: is defined before a selection region and used both in the selection region and
5+
; COM: after it, is not sunk into that selection region by the deserializer. If the
6+
; COM: variable is sunk, then it cannot be accessed outside the region and causes
7+
; COM: control-flow structurization to fail.
8+
9+
; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
10+
; CHECK: spirv.func @main() "None" {
11+
; CHECK: spirv.Variable : !spirv.ptr<vector<4xf32>, Function>
12+
; CHECK: spirv.mlir.selection {
13+
; CHECK-NEXT: spirv.BranchConditional {{.*}}, ^[[bb:.+]], ^[[bb:.+]]
14+
; CHECK-NEXT: ^[[bb:.+]]
15+
; CHECK: spirv.Branch ^[[bb:.+]]
16+
; CHECK-NEXT: ^[[bb:.+]]:
17+
; CHECK-NEXT: spirv.mlir.merge
18+
; CHECK-NEXT: }
19+
; CHECK: spirv.Return
20+
; CHECK-NEXT: }
21+
; CHECK: }
22+
23+
OpCapability Shader
24+
%2 = OpExtInstImport "GLSL.std.450"
25+
OpMemoryModel Logical GLSL450
26+
OpEntryPoint Fragment %main "main" %colorOut
27+
OpExecutionMode %main OriginUpperLeft
28+
OpDecorate %colorOut Location 0
29+
%void = OpTypeVoid
30+
%4 = OpTypeFunction %void
31+
%float = OpTypeFloat 32
32+
%v4float = OpTypeVector %float 4
33+
%fun_v4float = OpTypePointer Function %v4float
34+
%float_1 = OpConstant %float 1
35+
%float_0 = OpConstant %float 0
36+
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
37+
%out_v4float = OpTypePointer Output %v4float
38+
%colorOut = OpVariable %out_v4float Output
39+
%uint = OpTypeInt 32 0
40+
%uint_0 = OpConstant %uint 0
41+
%out_float = OpTypePointer Output %float
42+
%bool = OpTypeBool
43+
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
44+
%main = OpFunction %void None %4
45+
%6 = OpLabel
46+
%color = OpVariable %fun_v4float Function
47+
OpStore %color %13
48+
%19 = OpAccessChain %out_float %colorOut %uint_0
49+
%20 = OpLoad %float %19
50+
%22 = OpFOrdEqual %bool %20 %float_1
51+
OpSelectionMerge %24 None
52+
OpBranchConditional %22 %23 %24
53+
%23 = OpLabel
54+
OpStore %color %25
55+
OpBranch %24
56+
%24 = OpLabel
57+
%26 = OpLoad %v4float %color
58+
OpStore %colorOut %26
59+
OpReturn
60+
OpFunctionEnd

mlir/test/lit.cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
".test",
4444
".pdll",
4545
".c",
46+
".spv",
4647
]
4748

4849
# test_source_root: The root path where tests are located.

0 commit comments

Comments
 (0)