Skip to content

Commit b293d45

Browse files
IgWod-IMGkcloudy0717
authored andcommitted
[mlir][spirv] Enable block splitting for spirv.Switch (llvm#170147)
This is not strictly necessary as now selection regions can yield values, however splitting the block simplifies the code as it avoids unnecessary values being sunk just to be later yielded.
1 parent 1455118 commit b293d45

File tree

3 files changed

+81
-12
lines changed

3 files changed

+81
-12
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2868,7 +2868,7 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
28682868
return success();
28692869
}
28702870

2871-
LogicalResult spirv::Deserializer::splitConditionalBlocks() {
2871+
LogicalResult spirv::Deserializer::splitSelectionHeader() {
28722872
// Create a copy, so we can modify keys in the original.
28732873
BlockMergeInfoMap blockMergeInfoCopy = blockMergeInfo;
28742874
for (auto it = blockMergeInfoCopy.begin(), e = blockMergeInfoCopy.end();
@@ -2885,7 +2885,7 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
28852885
Operation *terminator = block->getTerminator();
28862886
assert(terminator);
28872887

2888-
if (!isa<spirv::BranchConditionalOp>(terminator))
2888+
if (!isa<spirv::BranchConditionalOp, spirv::SwitchOp>(terminator))
28892889
continue;
28902890

28912891
// Check if the current header block is a merge block of another construct.
@@ -2895,10 +2895,10 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
28952895
splitHeaderMergeBlock = true;
28962896
}
28972897

2898-
// Do not split a block that only contains a conditional branch, unless it
2899-
// is also a merge block of another construct - in that case we want to
2900-
// split the block. We do not want two constructs to share header / merge
2901-
// block.
2898+
// Do not split a block that only contains a conditional branch / switch,
2899+
// unless it is also a merge block of another construct - in that case we
2900+
// want to split the block. We do not want two constructs to share header /
2901+
// merge block.
29022902
if (!llvm::hasSingleElement(*block) || splitHeaderMergeBlock) {
29032903
Block *newBlock = block->splitBlock(terminator);
29042904
OpBuilder builder(block, block->end());
@@ -2936,7 +2936,7 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
29362936
logger.startLine() << "\n";
29372937
});
29382938

2939-
if (failed(splitConditionalBlocks())) {
2939+
if (failed(splitSelectionHeader())) {
29402940
return failure();
29412941
}
29422942

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -280,11 +280,11 @@ class Deserializer {
280280
return opBuilder.getStringAttr(attrName);
281281
}
282282

283-
/// Move a conditional branch into a separate basic block to avoid unnecessary
284-
/// sinking of defs that may be required outside a selection region. This
285-
/// function also ensures that a single block cannot be a header block of one
286-
/// selection construct and the merge block of another.
287-
LogicalResult splitConditionalBlocks();
283+
/// Move a conditional branch or a switch into a separate basic block to avoid
284+
/// unnecessary sinking of defs that may be required outside a selection
285+
/// region. This function also ensures that a single block cannot be a header
286+
/// block of one selection construct and the merge block of another.
287+
LogicalResult splitSelectionHeader();
288288

289289
//===--------------------------------------------------------------------===//
290290
// Type
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
; RUN: %if spirv-tools %{ spirv-as --target-env spv1.0 %s -o - | mlir-translate --deserialize-spirv - -o - | FileCheck %s %}
2+
3+
; This test is analogous to selection.spv but tests switch op.
4+
5+
; CHECK: spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
6+
; CHECK-NEXT: spirv.func @switch({{%.*}}: si32) "None" {
7+
; CHECK: {{%.*}} = spirv.Constant 1.000000e+00 : f32
8+
; CHECK-NEXT: {{%.*}} = spirv.Undef : vector<3xf32>
9+
; CHECK-NEXT: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
10+
; CHECK-NEXT: spirv.Branch ^[[bb:.+]]
11+
; CHECK-NEXT: ^[[bb:.+]]:
12+
; CHECK-NEXT: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
13+
; CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
14+
; CHECK-NEXT: default: ^[[bb:.+]]({{%.*}}: vector<3xf32>),
15+
; CHECK-NEXT: 0: ^[[bb:.+]]({{%.*}}: vector<3xf32>),
16+
; CHECK-NEXT: 1: ^[[bb:.+]]({{%.*}}: vector<3xf32>)
17+
; CHECK: ^[[bb:.+]]({{%.*}}: vector<3xf32>):
18+
; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>)
19+
; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>):
20+
; CHECK: spirv.Branch ^[[bb:.+]]({{%.*}}: vector<3xf32>)
21+
; CHECK-NEXT: ^[[bb:.+]]({{%.*}}: vector<3xf32>):
22+
; CHECK-NEXT: spirv.mlir.merge %8 : vector<3xf32>
23+
; CHECK-NEXT }
24+
; CHECK: spirv.Return
25+
; CHECK-NEXT: }
26+
; CHECK: }
27+
28+
OpCapability Shader
29+
OpMemoryModel Logical GLSL450
30+
OpEntryPoint GLCompute %main "main"
31+
OpExecutionMode %main LocalSize 1 1 1
32+
OpName %switch "switch"
33+
OpName %main "main"
34+
%void = OpTypeVoid
35+
%int = OpTypeInt 32 1
36+
%1 = OpTypeFunction %void %int
37+
%float = OpTypeFloat 32
38+
%float_1 = OpConstant %float 1
39+
%v3float = OpTypeVector %float 3
40+
%9 = OpUndef %v3float
41+
%float_3 = OpConstant %float 3
42+
%float_4 = OpConstant %float 4
43+
%float_2 = OpConstant %float 2
44+
%25 = OpTypeFunction %void
45+
%switch = OpFunction %void None %1
46+
%5 = OpFunctionParameter %int
47+
%6 = OpLabel
48+
OpBranch %12
49+
%12 = OpLabel
50+
%11 = OpCompositeInsert %v3float %float_1 %9 0
51+
OpSelectionMerge %15 None
52+
OpSwitch %5 %15 0 %13 1 %14
53+
%13 = OpLabel
54+
%16 = OpPhi %v3float %11 %12
55+
%18 = OpCompositeInsert %v3float %float_3 %16 1
56+
OpBranch %15
57+
%14 = OpLabel
58+
%19 = OpPhi %v3float %11 %12
59+
%21 = OpCompositeInsert %v3float %float_4 %19 1
60+
OpBranch %15
61+
%15 = OpLabel
62+
%22 = OpPhi %v3float %21 %14 %18 %13 %11 %12
63+
%24 = OpCompositeInsert %v3float %float_2 %22 2
64+
OpReturn
65+
OpFunctionEnd
66+
%main = OpFunction %void None %25
67+
%27 = OpLabel
68+
OpReturn
69+
OpFunctionEnd

0 commit comments

Comments
 (0)