Skip to content

Commit 82d3280

Browse files
committed
[mlir][spirv] Add support for SwitchOp
The dialect implementation mostly copies the one of `cf.switch`, but aligns naming to the SPIR-V spec.
1 parent 3b83e7f commit 82d3280

File tree

12 files changed

+565
-1
lines changed

12 files changed

+565
-1
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg
45314531
def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
45324532
def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
45334533
def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
4534+
def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>;
45344535
def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>;
45354536
def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
45364537
def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
@@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr :
46814682
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
46824683
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
46834684
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
4684-
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
4685+
SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
46854686
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
46864687
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
46874688
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
242242
}];
243243
}
244244

245+
// -----
246+
247+
def SPIRV_SwitchOp : SPIRV_Op<"Switch",
248+
[AttrSizedOperandSegments, InFunctionScope,
249+
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
250+
Pure, Terminator]> {
251+
let summary = [{
252+
Multi-way branch to one of the operand label <id>.
253+
}];
254+
255+
let description = [{
256+
Selector must have a type of OpTypeInt. Selector is compared for equality to
257+
the Target literals.
258+
259+
Default must be the <id> of a label. If Selector does not equal any of the
260+
Target literals, control flow branches to the Default label <id>.
261+
262+
Target must be alternating scalar integer literals and the <id> of a label.
263+
If Selector equals a literal, control flow branches to the following label
264+
<id>. It is invalid for any two literal to be equal to each other. If Selector
265+
does not equal any literal, control flow branches to the Default label <id>.
266+
Each literal is interpreted with the type of Selector: The bit width of
267+
Selector’s type is the width of each literal’s type. If this width is not a
268+
multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values
269+
are interpreted as being sign extended.
270+
271+
If Selector is an OpUndef, behavior is undefined.
272+
273+
This instruction must be the last instruction in a block.
274+
275+
#### Example:
276+
277+
```mlir
278+
spirv.Switch %selector : si32, [
279+
default: ^bb1(%a : i32),
280+
0: ^bb1(%b : i32),
281+
1: ^bb3(%c : i32)
282+
]
283+
```
284+
}];
285+
286+
let arguments = (ins
287+
SPIRV_Integer:$selector,
288+
Variadic<AnyType>:$defaultOperands,
289+
VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands,
290+
OptionalAttr<AnyIntElementsAttr>:$literals,
291+
DenseI32ArrayAttr:$case_operand_segments
292+
);
293+
294+
let results = (outs);
295+
296+
let successors = (successor AnySuccessor:$defaultTarget,
297+
VariadicSuccessor<AnySuccessor>:$targets);
298+
299+
let builders = [
300+
OpBuilder<(ins "Value":$selector,
301+
"Block *":$defaultTarget,
302+
"ValueRange":$defaultOperands,
303+
CArg<"ArrayRef<APInt>", "{}">:$literals,
304+
CArg<"BlockRange", "{}">:$targets,
305+
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
306+
OpBuilder<(ins "Value":$selector,
307+
"Block *":$defaultTarget,
308+
"ValueRange":$defaultOperands,
309+
CArg<"ArrayRef<int32_t>", "{}">:$literals,
310+
CArg<"BlockRange", "{}">:$targets,
311+
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
312+
OpBuilder<(ins "Value":$selector,
313+
"Block *":$defaultTarget,
314+
"ValueRange":$defaultOperands,
315+
CArg<"DenseIntElementsAttr", "{}">:$literals,
316+
CArg<"BlockRange", "{}">:$targets,
317+
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>
318+
];
319+
320+
let assemblyFormat = [{
321+
$selector `:` type($selector) `,` `[` `\n`
322+
custom<SwitchOpCases>(ref(type($selector)),$defaultTarget,
323+
$defaultOperands,
324+
type($defaultOperands),
325+
$literals,
326+
$targets,
327+
$targetOperands,
328+
type($targetOperands))
329+
`]`
330+
attr-dict
331+
}];
332+
333+
let extraClassDeclaration = [{
334+
/// Return the operands for the target block at the given index.
335+
OperandRange getTargetOperands(unsigned index) {
336+
return getTargetOperands()[index];
337+
}
338+
339+
/// Return a mutable range of operands for the target block at the
340+
/// given index.
341+
MutableOperandRange getTargetOperandsMutable(unsigned index) {
342+
return getTargetOperandsMutable()[index];
343+
}
344+
}];
345+
346+
let autogenSerialization = 0;
347+
let hasVerifier = 1;
348+
}
349+
350+
245351
// -----
246352

247353
def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> {

mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
219219
return getArgumentsMutable();
220220
}
221221

222+
//===----------------------------------------------------------------------===//
223+
// spirv.Switch
224+
//===----------------------------------------------------------------------===//
225+
226+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
227+
Block *defaultTarget, ValueRange defaultOperands,
228+
DenseIntElementsAttr literals, BlockRange targets,
229+
ArrayRef<ValueRange> targetOperands) {
230+
build(builder, result, selector, defaultOperands, targetOperands, literals,
231+
defaultTarget, targets);
232+
}
233+
234+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
235+
Block *defaultTarget, ValueRange defaultOperands,
236+
ArrayRef<APInt> literals, BlockRange targets,
237+
ArrayRef<ValueRange> targetOperands) {
238+
DenseIntElementsAttr literalsAttr;
239+
if (!literals.empty()) {
240+
ShapedType literalType = VectorType::get(
241+
static_cast<int64_t>(literals.size()), selector.getType());
242+
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
243+
}
244+
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
245+
targets, targetOperands);
246+
}
247+
248+
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
249+
Block *defaultTarget, ValueRange defaultOperands,
250+
ArrayRef<int32_t> literals, BlockRange targets,
251+
ArrayRef<ValueRange> targetOperands) {
252+
DenseIntElementsAttr literalsAttr;
253+
if (!literals.empty()) {
254+
ShapedType literalType = VectorType::get(
255+
static_cast<int64_t>(literals.size()), selector.getType());
256+
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
257+
}
258+
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
259+
targets, targetOperands);
260+
}
261+
262+
LogicalResult SwitchOp::verify() {
263+
std::optional<DenseIntElementsAttr> literals = getLiterals();
264+
BlockRange targets = getTargets();
265+
266+
if (!literals && targets.empty())
267+
return success();
268+
269+
Type selectorType = getSelector().getType();
270+
Type literalType = literals->getType().getElementType();
271+
if (literalType != selectorType)
272+
return emitOpError() << "'selector' type (" << selectorType
273+
<< ") should match literals type (" << literalType
274+
<< ")";
275+
276+
if (literals && literals->size() != static_cast<int64_t>(targets.size()))
277+
return emitOpError() << "number of literals (" << literals->size()
278+
<< ") should match number of targets ("
279+
<< targets.size() << ")";
280+
return success();
281+
}
282+
283+
SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
284+
assert(index < getNumSuccessors() && "invalid successor index");
285+
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
286+
: getTargetOperandsMutable(index - 1));
287+
}
288+
289+
Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
290+
std::optional<DenseIntElementsAttr> literals = getLiterals();
291+
292+
if (!literals)
293+
return getDefaultTarget();
294+
295+
SuccessorRange targets = getTargets();
296+
if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
297+
for (const auto &it : llvm::enumerate(literals->getValues<APInt>()))
298+
if (it.value() == value.getValue())
299+
return targets[it.index()];
300+
return getDefaultTarget();
301+
}
302+
return nullptr;
303+
}
304+
222305
//===----------------------------------------------------------------------===//
223306
// spirv.mlir.loop
224307
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
8181
}
8282
}
8383

84+
/// Adapted from the cf.switch implementation.
85+
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
86+
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
87+
static ParseResult parseSwitchOpCases(
88+
OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
89+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
90+
SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
91+
SmallVectorImpl<Block *> &targets,
92+
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
93+
&targetOperands,
94+
SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
95+
if (parser.parseKeyword("default") || parser.parseColon() ||
96+
parser.parseSuccessor(defaultTarget))
97+
return failure();
98+
if (succeeded(parser.parseOptionalLParen())) {
99+
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
100+
/*allowResultNumber=*/false) ||
101+
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
102+
return failure();
103+
}
104+
105+
SmallVector<APInt> values;
106+
unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
107+
while (succeeded(parser.parseOptionalComma())) {
108+
int64_t value = 0;
109+
if (failed(parser.parseInteger(value)))
110+
return failure();
111+
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));
112+
113+
Block *target;
114+
SmallVector<OpAsmParser::UnresolvedOperand> operands;
115+
SmallVector<Type> operandTypes;
116+
if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
117+
return failure();
118+
if (succeeded(parser.parseOptionalLParen())) {
119+
if (failed(parser.parseOperandList(operands,
120+
OpAsmParser::Delimiter::None)) ||
121+
failed(parser.parseColonTypeList(operandTypes)) ||
122+
failed(parser.parseRParen()))
123+
return failure();
124+
}
125+
targets.push_back(target);
126+
targetOperands.emplace_back(operands);
127+
targetOperandTypes.emplace_back(operandTypes);
128+
}
129+
130+
if (!values.empty()) {
131+
ShapedType literalType =
132+
VectorType::get(static_cast<int64_t>(values.size()), selectorType);
133+
literals = DenseIntElementsAttr::get(literalType, values);
134+
}
135+
return success();
136+
}
137+
138+
static void
139+
printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
140+
Block *defaultTarget, OperandRange defaultOperands,
141+
TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
142+
SuccessorRange targets, OperandRangeRange targetOperands,
143+
const TypeRangeRange &targetOperandTypes) {
144+
p << " default: ";
145+
p.printSuccessorAndUseList(defaultTarget, defaultOperands);
146+
147+
if (!literals)
148+
return;
149+
150+
for (const auto &it : llvm::enumerate(literals.getValues<APInt>())) {
151+
p << ',';
152+
p.printNewline();
153+
p << " ";
154+
p << it.value().getLimitedValue();
155+
p << ": ";
156+
p.printSuccessorAndUseList(targets[it.index()], targetOperands[it.index()]);
157+
}
158+
p.printNewline();
159+
}
160+
84161
} // namespace mlir::spirv
85162

86163
// TablenGen'erated operation definitions.

mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
248248
return processLoopMerge(operands);
249249
case spirv::Opcode::OpPhi:
250250
return processPhi(operands);
251+
case spirv::Opcode::OpSwitch:
252+
return processSwitch(operands);
251253
case spirv::Opcode::OpUndef:
252254
return processUndef(operands);
253255
default:

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2292,6 +2292,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
22922292
return success();
22932293
}
22942294

2295+
LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) {
2296+
if (!curBlock)
2297+
return emitError(unknownLoc, "OpSwitch must appear in a block");
2298+
2299+
if (operands.size() < 2)
2300+
return emitError(unknownLoc, "OpSwitch must at least specify selector and "
2301+
"a default target");
2302+
2303+
if (operands.size() % 2)
2304+
return emitError(unknownLoc,
2305+
"OpSwitch must at have an even number of operands: "
2306+
"selector, default target and any number of literal and "
2307+
"label <id> pairs");
2308+
2309+
Value selector = getValue(operands[0]);
2310+
Block *defaultBlock = getOrCreateBlock(operands[1]);
2311+
Location loc = createFileLineColLoc(opBuilder);
2312+
2313+
SmallVector<int32_t> literals;
2314+
SmallVector<Block *> blocks;
2315+
for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
2316+
literals.push_back(operands[i]);
2317+
blocks.push_back(getOrCreateBlock(operands[i + 1]));
2318+
}
2319+
2320+
SmallVector<ValueRange> targetOperands(blocks.size(), {});
2321+
spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
2322+
ArrayRef<Value>(), literals, blocks, targetOperands);
2323+
2324+
return success();
2325+
}
2326+
22952327
namespace {
22962328
/// A class for putting all blocks in a structured selection/loop in a
22972329
/// spirv.mlir.selection/spirv.mlir.loop op.

mlir/lib/Target/SPIRV/Deserialization/Deserializer.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -472,6 +472,9 @@ class Deserializer {
472472
/// Processes a SPIR-V OpPhi instruction with the given `operands`.
473473
LogicalResult processPhi(ArrayRef<uint32_t> operands);
474474

475+
/// Processes a SPIR-V OpSwitch instruction with the given `operands`.
476+
LogicalResult processSwitch(ArrayRef<uint32_t> operands);
477+
475478
/// Creates block arguments on predecessors previously recorded when handling
476479
/// OpPhi instructions.
477480
LogicalResult wireUpBlockArgument();

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
775775
return success();
776776
}
777777

778+
LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) {
779+
uint32_t selectorID = getValueID(switchOp.getSelector());
780+
uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget());
781+
SmallVector<uint32_t> arguments{selectorID, defaultLabelID};
782+
783+
std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals();
784+
BlockRange targets = switchOp.getTargets();
785+
if (literals) {
786+
for (auto [literal, target] : llvm::zip(*literals, targets)) {
787+
arguments.push_back(literal.getLimitedValue());
788+
uint32_t targetLabelID = getOrCreateBlockID(target);
789+
arguments.push_back(targetLabelID);
790+
}
791+
}
792+
793+
if (failed(emitDebugLine(functionBody, switchOp.getLoc())))
794+
return failure();
795+
encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments);
796+
return success();
797+
}
798+
778799
LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
779800
auto varName = addressOfOp.getVariable();
780801
auto variableID = getVariableID(varName);

mlir/lib/Target/SPIRV/Serialization/Serializer.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
15791579
.Case([&](spirv::SpecConstantOperationOp op) {
15801580
return processSpecConstantOperationOp(op);
15811581
})
1582+
.Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
15821583
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
15831584
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })
15841585

0 commit comments

Comments
 (0)