Skip to content

Commit 891b3cf

Browse files
authored
[mlir][spirv] Add support for SwitchOp (#168713)
The dialect implementation mostly copies the one of `cf.switch`, but aligns naming to the SPIR-V spec.
1 parent 0e54667 commit 891b3cf

File tree

13 files changed

+669
-2
lines changed

13 files changed

+669
-2
lines changed

mlir/docs/Dialects/SPIR-V.md

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ merge block.
566566
For example, for the given function
567567

568568
```c++
569-
void loop(bool cond) {
569+
void if(bool cond) {
570570
int x = 0;
571571
if (cond) {
572572
x = 1;
@@ -605,6 +605,62 @@ func.func @selection(%cond: i1) -> () {
605605
}
606606
```
607607

608+
Similarly, for the give function with a `switch` statement
609+
610+
```c++
611+
void switch(int selector) {
612+
int x = 0;
613+
switch (selector) {
614+
case 0:
615+
x = 2;
616+
break;
617+
case 1:
618+
x = 3;
619+
break;
620+
default:
621+
x = 1;
622+
break;
623+
}
624+
// ...
625+
}
626+
```
627+
628+
It will be represented as
629+
630+
```mlir
631+
func.func @selection(%selector: i32) -> () {
632+
%zero = spirv.Constant 0: i32
633+
%one = spirv.Constant 1: i32
634+
%two = spirv.Constant 2: i32
635+
%three = spirv.Constant 3: i32
636+
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
637+
638+
spirv.mlir.selection {
639+
spirv.Switch %selector : i32, [
640+
default: ^default,
641+
0: ^case0,
642+
1: ^case1
643+
]
644+
^default:
645+
spirv.Store "Function" %var, %one : i32
646+
spirv.Branch ^merge
647+
648+
^case0:
649+
spirv.Store "Function" %var, %two : i32
650+
spirv.Branch ^merge
651+
652+
^case1:
653+
spirv.Store "Function" %var, %three : i32
654+
spirv.Branch ^merge
655+
656+
^merge:
657+
spirv.mlir.merge
658+
}
659+
660+
// ...
661+
}
662+
```
663+
608664
The selection can return values by yielding them with `spirv.mlir.merge`. This
609665
mechanism allows values defined within the selection region to be used outside of it.
610666
Without this, values that were sunk into the selection region, but used outside, would

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg
45314531
def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
45324532
def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
45334533
def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
4534+
def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>;
45344535
def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>;
45354536
def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
45364537
def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
@@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr :
46814682
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
46824683
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
46834684
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
4684-
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
4685+
SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
46854686
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
46864687
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
46874688
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
242242
}];
243243
}
244244

245+
// -----
246+
247+
def SPIRV_SwitchOp : SPIRV_Op<"Switch",
248+
[AttrSizedOperandSegments, InFunctionScope,
249+
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
250+
Pure, Terminator]> {
251+
let summary = [{
252+
Multi-way branch to one of the operand label <id>.
253+
}];
254+
255+
let description = [{
256+
Selector must have a type of OpTypeInt. Selector is compared for equality to
257+
the Target literals.
258+
259+
Default must be the <id> of a label. If Selector does not equal any of the
260+
Target literals, control flow branches to the Default label <id>.
261+
262+
Target must be alternating scalar integer literals and the <id> of a label.
263+
If Selector equals a literal, control flow branches to the following label
264+
<id>. It is invalid for any two literal to be equal to each other. If Selector
265+
does not equal any literal, control flow branches to the Default label <id>.
266+
Each literal is interpreted with the type of Selector: The bit width of
267+
Selector’s type is the width of each literal’s type. If this width is not a
268+
multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values
269+
are interpreted as being sign extended.
270+
271+
If Selector is an OpUndef, behavior is undefined.
272+
273+
This instruction must be the last instruction in a block.
274+
275+
#### Example:
276+
277+
```mlir
278+
spirv.Switch %selector : si32, [
279+
default: ^bb1(%a : i32),
280+
0: ^bb1(%b : i32),
281+
1: ^bb3(%c : i32)
282+
]
283+
```
284+
}];
285+
286+
let arguments = (ins
287+
SPIRV_Integer:$selector,
288+
Variadic<AnyType>:$defaultOperands,
289+
VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands,
290+
OptionalAttr<AnyIntElementsAttr>:$literals,
291+
DenseI32ArrayAttr:$case_operand_segments
292+
);
293+
294+
let results = (outs);
295+
296+
let successors = (successor AnySuccessor:$defaultTarget,
297+
VariadicSuccessor<AnySuccessor>:$targets);
298+
299+
let builders = [
300+
OpBuilder<(ins "Value":$selector,
301+
"Block *":$defaultTarget,
302+
"ValueRange":$defaultOperands,
303+
CArg<"ArrayRef<APInt>", "{}">:$literals,
304+
CArg<"BlockRange", "{}">:$targets,
305+
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
306+
OpBuilder<(ins "Value":$selector,
307+
"Block *":$defaultTarget,
308+
"ValueRange":$defaultOperands,
309+
CArg<"ArrayRef<int32_t>", "{}">:$literals,
310+
CArg<"BlockRange", "{}">:$targets,
311+
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
312+
OpBuilder<(ins "Value":$selector,
313+
"Block *":$defaultTarget,
314+
"ValueRange":$defaultOperands,
315+
CArg<"DenseIntElementsAttr", "{}">:$literals,
316+
CArg<"BlockRange", "{}">:$targets,
317+
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>
318+
];
319+
320+
let assemblyFormat = [{
321+
$selector `:` type($selector) `,` `[` `\n`
322+
custom<SwitchOpCases>(ref(type($selector)),$defaultTarget,
323+
$defaultOperands,
324+
type($defaultOperands),
325+
$literals,
326+
$targets,
327+
$targetOperands,
328+
type($targetOperands))
329+
`]`
330+
attr-dict
331+
}];
332+
333+
let extraClassDeclaration = [{
334+
/// Return the operands for the target block at the given index.
335+
OperandRange getTargetOperands(unsigned index) {
336+
return getTargetOperands()[index];
337+
}
338+
339+
/// Return a mutable range of operands for the target block at the
340+
/// given index.
341+
MutableOperandRange getTargetOperandsMutable(unsigned index) {
342+
return getTargetOperandsMutable()[index];
343+
}
344+
}];
345+
346+
let autogenSerialization = 0;
347+
let hasVerifier = 1;
348+
}
349+
350+
245351
// -----
246352

247353
def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> {

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
219219
return getArgumentsMutable();
220220
}
221221

222+
//===----------------------------------------------------------------------===//
223+
// spirv.Switch
224+
//===----------------------------------------------------------------------===//
225+
226+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
227+
Block *defaultTarget, ValueRange defaultOperands,
228+
DenseIntElementsAttr literals, BlockRange targets,
229+
ArrayRef<ValueRange> targetOperands) {
230+
build(builder, result, selector, defaultOperands, targetOperands, literals,
231+
defaultTarget, targets);
232+
}
233+
234+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
235+
Block *defaultTarget, ValueRange defaultOperands,
236+
ArrayRef<APInt> literals, BlockRange targets,
237+
ArrayRef<ValueRange> targetOperands) {
238+
DenseIntElementsAttr literalsAttr;
239+
if (!literals.empty()) {
240+
ShapedType literalType = VectorType::get(
241+
static_cast<int64_t>(literals.size()), selector.getType());
242+
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
243+
}
244+
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
245+
targets, targetOperands);
246+
}
247+
248+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
249+
Block *defaultTarget, ValueRange defaultOperands,
250+
ArrayRef<int32_t> literals, BlockRange targets,
251+
ArrayRef<ValueRange> targetOperands) {
252+
DenseIntElementsAttr literalsAttr;
253+
if (!literals.empty()) {
254+
ShapedType literalType = VectorType::get(
255+
static_cast<int64_t>(literals.size()), selector.getType());
256+
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
257+
}
258+
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
259+
targets, targetOperands);
260+
}
261+
262+
LogicalResult SwitchOp::verify() {
263+
std::optional<DenseIntElementsAttr> literals = getLiterals();
264+
BlockRange targets = getTargets();
265+
266+
if (!literals && targets.empty())
267+
return success();
268+
269+
Type selectorType = getSelector().getType();
270+
Type literalType = literals->getType().getElementType();
271+
if (literalType != selectorType)
272+
return emitOpError() << "'selector' type (" << selectorType
273+
<< ") should match literals type (" << literalType
274+
<< ")";
275+
276+
if (literals && literals->size() != static_cast<int64_t>(targets.size()))
277+
return emitOpError() << "number of literals (" << literals->size()
278+
<< ") should match number of targets ("
279+
<< targets.size() << ")";
280+
return success();
281+
}
282+
283+
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
284+
assert(index < getNumSuccessors() && "invalid successor index");
285+
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
286+
: getTargetOperandsMutable(index - 1));
287+
}
288+
289+
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
290+
std::optional<DenseIntElementsAttr> literals = getLiterals();
291+
292+
if (!literals)
293+
return getDefaultTarget();
294+
295+
SuccessorRange targets = getTargets();
296+
if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) {
297+
for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
298+
if (literal == value.getValue())
299+
return targets[index];
300+
return getDefaultTarget();
301+
}
302+
return nullptr;
303+
}
304+
222305
//===----------------------------------------------------------------------===//
223306
// spirv.mlir.loop
224307
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
8181
}
8282
}
8383

84+
/// Adapted from the cf.switch implementation.
85+
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
86+
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
87+
static ParseResult parseSwitchOpCases(
88+
OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
89+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
90+
SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
91+
SmallVectorImpl<Block *> &targets,
92+
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
93+
&targetOperands,
94+
SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
95+
if (parser.parseKeyword("default") || parser.parseColon() ||
96+
parser.parseSuccessor(defaultTarget))
97+
return failure();
98+
if (succeeded(parser.parseOptionalLParen())) {
99+
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
100+
/*allowResultNumber=*/false) ||
101+
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
102+
return failure();
103+
}
104+
105+
SmallVector<APInt> values;
106+
unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
107+
while (succeeded(parser.parseOptionalComma())) {
108+
int64_t value = 0;
109+
if (failed(parser.parseInteger(value)))
110+
return failure();
111+
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
112+
113+
Block *target;
114+
SmallVector<OpAsmParser::UnresolvedOperand> operands;
115+
SmallVector<Type> operandTypes;
116+
if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
117+
return failure();
118+
if (succeeded(parser.parseOptionalLParen())) {
119+
if (failed(parser.parseOperandList(operands,
120+
OpAsmParser::Delimiter::None)) ||
121+
failed(parser.parseColonTypeList(operandTypes)) ||
122+
failed(parser.parseRParen()))
123+
return failure();
124+
}
125+
targets.push_back(target);
126+
targetOperands.emplace_back(operands);
127+
targetOperandTypes.emplace_back(operandTypes);
128+
}
129+
130+
if (!values.empty()) {
131+
ShapedType literalType =
132+
VectorType::get(static_cast<int64_t>(values.size()), selectorType);
133+
literals = DenseIntElementsAttr::get(literalType, values);
134+
}
135+
return success();
136+
}
137+
138+
static void
139+
printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
140+
Block *defaultTarget, OperandRange defaultOperands,
141+
TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
142+
SuccessorRange targets, OperandRangeRange targetOperands,
143+
const TypeRangeRange &targetOperandTypes) {
144+
p << " default: ";
145+
p.printSuccessorAndUseList(defaultTarget, defaultOperands);
146+
147+
if (!literals)
148+
return;
149+
150+
for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
151+
p << ',';
152+
p.printNewline();
153+
p << " ";
154+
p << literal.getLimitedValue();
155+
p << ": ";
156+
p.printSuccessorAndUseList(targets[index], targetOperands[index]);
157+
}
158+
p.printNewline();
159+
}
160+
84161
} // namespace mlir::spirv
85162

86163
// TablenGen'erated operation definitions.

mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
248248
return processLoopMerge(operands);
249249
case spirv::Opcode::OpPhi:
250250
return processPhi(operands);
251+
case spirv::Opcode::OpSwitch:
252+
return processSwitch(operands);
251253
case spirv::Opcode::OpUndef:
252254
return processUndef(operands);
253255
default:

0 commit comments

Comments
 (0)