Skip to content

Commit 1f6bd90

Browse files
committed
[mlir][spirv] Support (de)serialization of block operands in spirv.Switch
1 parent 21fedcb commit 1f6bd90

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2831,6 +2831,23 @@ LogicalResult spirv::Deserializer::wireUpBlockArgument() {
28312831
branchCondOp.getFalseBlock());
28322832

28332833
branchCondOp.erase();
2834+
} else if (auto switchOp = dyn_cast<spirv::SwitchOp>(op)) {
2835+
if (target == switchOp.getDefaultTarget()) {
2836+
SmallVector<ValueRange> targetOperands(switchOp.getTargetOperands());
2837+
DenseIntElementsAttr literals =
2838+
switchOp.getLiterals().value_or(DenseIntElementsAttr());
2839+
spirv::SwitchOp::create(
2840+
opBuilder, switchOp.getLoc(), switchOp.getSelector(),
2841+
switchOp.getDefaultTarget(), blockArgs, literals,
2842+
switchOp.getTargets(), targetOperands);
2843+
switchOp.erase();
2844+
} else {
2845+
SuccessorRange targets = switchOp.getTargets();
2846+
auto it = llvm::find(targets, target);
2847+
assert(it != targets.end());
2848+
size_t index = std::distance(targets.begin(), it);
2849+
switchOp.getTargetOperandsMutable(index).assign(blockArgs);
2850+
}
28342851
} else {
28352852
return emitError(unknownLoc, "unimplemented terminator for Phi creation");
28362853
}

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,20 @@ LogicalResult Serializer::emitPhiForBlockArguments(Block *block) {
14431443
assert(branchCondOp.getFalseTarget() == block);
14441444
blockOperands = branchCondOp.getFalseTargetOperands();
14451445
}
1446-
1446+
assert(!blockOperands->empty() &&
1447+
"expected non-empty block operand range");
1448+
predecessors.emplace_back(spirvPredecessor, *blockOperands);
1449+
} else if (auto switchOp = dyn_cast<spirv::SwitchOp>(terminator)) {
1450+
std::optional<OperandRange> blockOperands;
1451+
if (block == switchOp.getDefaultTarget()) {
1452+
blockOperands = switchOp.getDefaultOperands();
1453+
} else {
1454+
SuccessorRange targets = switchOp.getTargets();
1455+
auto it = llvm::find(targets, block);
1456+
assert(it != targets.end());
1457+
size_t index = std::distance(targets.begin(), it);
1458+
blockOperands = switchOp.getTargetOperands(index);
1459+
}
14471460
assert(!blockOperands->empty() &&
14481461
"expected non-empty block operand range");
14491462
predecessors.emplace_back(spirvPredecessor, *blockOperands);

mlir/test/Target/SPIRV/selection.mlir

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,61 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
288288
spirv.EntryPoint "GLCompute" @main
289289
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
290290
}
291+
292+
// -----
293+
294+
// Selection with switch and block operands
295+
296+
spirv.module Logical GLSL450 requires #spirv.vce<v1.5, [Shader], []> {
297+
// CHECK-LABEL: @selection_switch_operands
298+
spirv.func @selection_switch_operands(%selector : si32) "None" {
299+
%cst1 = spirv.Constant 1.000000e+00 : f32
300+
%vec0 = spirv.Undef : vector<3xf32>
301+
// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[0 : i32] : f32 into vector<3xf32>
302+
%vec1 = spirv.CompositeInsert %cst1, %vec0[0 : i32] : f32 into vector<3xf32>
303+
spirv.Branch ^bb1
304+
^bb1:
305+
// CHECK: {{%.*}} = spirv.mlir.selection -> vector<3xf32> {
306+
%vec4 = spirv.mlir.selection -> vector<3xf32> {
307+
// CHECK-NEXT: spirv.Switch {{%.*}} : si32, [
308+
// CHECK-NEXT: default: ^[[DEFAULT:.+]]({{%.*}} : vector<3xf32>),
309+
// CHECK-NEXT: 0: ^[[CASE0:.+]]({{%.*}} : vector<3xf32>),
310+
// CHECK-NEXT: 1: ^[[CASE1:.+]]({{%.*}} : vector<3xf32>)
311+
spirv.Switch %selector : si32, [
312+
default: ^bb3(%vec1 : vector<3xf32>),
313+
0: ^bb1(%vec1 : vector<3xf32>),
314+
1: ^bb2(%vec1 : vector<3xf32>)
315+
]
316+
// CHECK: ^[[CASE0]]({{%.*}}: vector<3xf32>)
317+
^bb1(%vecbb1: vector<3xf32>):
318+
%cst3 = spirv.Constant 3.000000e+00 : f32
319+
// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
320+
%vec2 = spirv.CompositeInsert %cst3, %vecbb1[1 : i32] : f32 into vector<3xf32>
321+
// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
322+
spirv.Branch ^bb3(%vec2 : vector<3xf32>)
323+
// CHECK-NEXT: ^[[CASE1]]({{%.*}}: vector<3xf32>)
324+
^bb2(%vecbb2: vector<3xf32>):
325+
%cst4 = spirv.Constant 4.000000e+00 : f32
326+
// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[1 : i32] : f32 into vector<3xf32>
327+
%vec3 = spirv.CompositeInsert %cst4, %vecbb2[1 : i32] : f32 into vector<3xf32>
328+
// CHECK-NEXT: spirv.Branch ^[[DEFAULT]]({{%.*}} : vector<3xf32>)
329+
spirv.Branch ^bb3(%vec3 : vector<3xf32>)
330+
// CHECK-NEXT: ^[[DEFAULT]]({{%.*}}: vector<3xf32>)
331+
^bb3(%vecbb3: vector<3xf32>):
332+
// CHECK-NEXT: spirv.mlir.merge {{%.*}} : vector<3xf32>
333+
spirv.mlir.merge %vecbb3 : vector<3xf32>
334+
// CHECK-NEXT: }
335+
}
336+
%cst2 = spirv.Constant 2.000000e+00 : f32
337+
// CHECK: {{%.*}} = spirv.CompositeInsert {{%.*}}, {{%.*}}[2 : i32] : f32 into vector<3xf32>
338+
%vec5 = spirv.CompositeInsert %cst2, %vec4[2 : i32] : f32 into vector<3xf32>
339+
spirv.Return
340+
}
341+
342+
spirv.func @main() -> () "None" {
343+
spirv.Return
344+
}
345+
346+
spirv.EntryPoint "GLCompute" @main
347+
spirv.ExecutionMode @main "LocalSize", 1, 1, 1
348+
}

0 commit comments

Comments
 (0)