Skip to content

Commit 1075180

Browse files
committed
[mlir][spirv] Split conditional basic block during deserialization
With the current design some of the values are sank into a selection region, despite them being also used outside that region. This is because the current deserializer logic sinks the entire basic block containing a conditional branch forming a header of a selection construct, without accounting for some values being used outside. This manifests as (for example): ``` <unknown>:0: error: 'spirv.Variable' op failed control flow structurization: it has uses outside of the enclosing selection/loop construct <unknown>:0: note: see current operation: %0 = "spirv.Variable"()<{storage_class = #spirv.storage_class<Function>}> : () -> !spirv.ptr<vector<4xf32>, Function> ``` The proposed solution to this problem is to split the conditional basic block into two, one block containing just the conditional branch, and other the rest of instructions. By doing this, the logic that structures selection regions, only sinks the comparison, keeping the rest of instructions outside the selection region. A SPIR-V test is required, as the problem can happen only during deserialization and cannot be tested with `--test-spirv-roundtrip`. An MLIR test exhibiting the problematic behaviour would be an incorrect MLIR in the first place. This solution is proposed as an alternative to an unfinished PR #123371, that is unlikely to be merged in the foreseeable future, as the author "stepped away from this for a time being". There is also a Discourse thread: https://discourse.llvm.org/t/spir-v-uses-outside-the-selection-region/84494 that tried, unfortunately unsuccessful, solicit some feedback on the topic.
1 parent 6646b65 commit 1075180

File tree

6 files changed

+91
-0
lines changed

6 files changed

+91
-0
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,13 +2158,55 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
21582158
return success();
21592159
}
21602160

2161+
LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2162+
auto splitBlock = [&](Block *block) {
2163+
// Do not split loop headers
2164+
if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
2165+
if (it->second.continueBlock) {
2166+
return;
2167+
}
2168+
}
2169+
2170+
if (!block->mightHaveTerminator())
2171+
return;
2172+
2173+
auto terminator = block->getTerminator();
2174+
assert(terminator != nullptr);
2175+
2176+
if (isa<spirv::BranchConditionalOp>(terminator) &&
2177+
std::distance(block->begin(), block->end()) > 1) {
2178+
auto newBlock = block->splitBlock(terminator);
2179+
OpBuilder builder(block, block->end());
2180+
builder.create<spirv::BranchOp>(block->getParent()->getLoc(), newBlock);
2181+
2182+
if (auto it = blockMergeInfo.find(block); it != blockMergeInfo.end()) {
2183+
auto value = std::move(it->second);
2184+
blockMergeInfo.erase(it);
2185+
blockMergeInfo.try_emplace(newBlock, std::move(value));
2186+
}
2187+
}
2188+
};
2189+
curFunction->walk(splitBlock);
2190+
2191+
return success();
2192+
}
2193+
21612194
LogicalResult spirv::Deserializer::structurizeControlFlow() {
21622195
LLVM_DEBUG({
21632196
logger.startLine()
21642197
<< "//----- [cf] start structurizing control flow -----//\n";
21652198
logger.indent();
21662199
});
21672200

2201+
LLVM_DEBUG({
2202+
logger.startLine() << "[cf] split conditional blocks\n";
2203+
logger.startLine() << "\n";
2204+
});
2205+
2206+
if (failed(splitConditionalBlocks())) {
2207+
return failure();
2208+
}
2209+
21682210
while (!blockMergeInfo.empty()) {
21692211
Block *headerBlock = blockMergeInfo.begin()->first;
21702212
BlockMergeInfo mergeInfo = blockMergeInfo.begin()->second;

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,6 +246,10 @@ class Deserializer {
246246
return opBuilder.getStringAttr(attrName);
247247
}
248248

249+
// Move a conditional branch into a separate basic block to avoid sinking
250+
// defs that are required outside a selection region.
251+
LogicalResult splitConditionalBlocks();
252+
249253
//===--------------------------------------------------------------------===//
250254
// Type
251255
//===--------------------------------------------------------------------===//

mlir/test/Target/SPIRV/loop.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
267267
}
268268
// CHECK-NEXT: %[[LOAD:.+]] = spirv.Load "Function" %[[VAR]] : i1
269269
%load = spirv.Load "Function" %var : i1
270+
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
271+
// CHECK-NEXT: ^[[BB]]
270272
// CHECK-NEXT: spirv.BranchConditional %[[LOAD]], ^[[CONTINUE:.+]](%[[ARG1]] : i64), ^[[LOOP_MERGE:.+]]
271273
spirv.BranchConditional %load, ^continue(%arg1 : i64), ^loop_merge
272274
// CHECK-NEXT: ^[[CONTINUE]](%[[ARG2:.+]]: i64):

mlir/test/Target/SPIRV/selection.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
105105
%var = spirv.Variable : !spirv.ptr<i1, Function>
106106
// CHECK-NEXT: spirv.Branch ^[[BB:.+]]
107107
// CHECK-NEXT: ^[[BB]]:
108+
// CHECK: spirv.Branch ^[[BB:.+]]
109+
// CHECK-NEXT: ^[[BB]]:
108110

109111
// CHECK-NEXT: spirv.mlir.selection {
110112
spirv.mlir.selection {
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
2+
; CHECK: spirv.module
3+
OpCapability Shader
4+
%2 = OpExtInstImport "GLSL.std.450"
5+
OpMemoryModel Logical GLSL450
6+
OpEntryPoint Fragment %main "main" %colorOut
7+
OpExecutionMode %main OriginUpperLeft
8+
OpDecorate %colorOut Location 0
9+
%void = OpTypeVoid
10+
%4 = OpTypeFunction %void
11+
%float = OpTypeFloat 32
12+
%v4float = OpTypeVector %float 4
13+
%fun_v4float = OpTypePointer Function %v4float
14+
%float_1 = OpConstant %float 1
15+
%float_0 = OpConstant %float 0
16+
%13 = OpConstantComposite %v4float %float_1 %float_0 %float_0 %float_1
17+
%out_v4float = OpTypePointer Output %v4float
18+
%colorOut = OpVariable %out_v4float Output
19+
%uint = OpTypeInt 32 0
20+
%uint_0 = OpConstant %uint 0
21+
%out_float = OpTypePointer Output %float
22+
%bool = OpTypeBool
23+
%25 = OpConstantComposite %v4float %float_1 %float_1 %float_0 %float_1
24+
%main = OpFunction %void None %4
25+
%6 = OpLabel
26+
%color = OpVariable %fun_v4float Function
27+
OpStore %color %13
28+
%19 = OpAccessChain %out_float %colorOut %uint_0
29+
%20 = OpLoad %float %19
30+
%22 = OpFOrdEqual %bool %20 %float_1
31+
OpSelectionMerge %24 None
32+
OpBranchConditional %22 %23 %24
33+
%23 = OpLabel
34+
OpStore %color %25
35+
OpBranch %24
36+
%24 = OpLabel
37+
%26 = OpLoad %v4float %color
38+
OpStore %colorOut %26
39+
OpReturn
40+
OpFunctionEnd

mlir/test/lit.cfg.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
".test",
4444
".pdll",
4545
".c",
46+
".spv"
4647
]
4748

4849
# test_source_root: The root path where tests are located.

0 commit comments

Comments
 (0)