Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 57 additions & 1 deletion mlir/docs/Dialects/SPIR-V.md
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ merge block.
For example, for the given function

```c++
void loop(bool cond) {
void if(bool cond) {
int x = 0;
if (cond) {
x = 1;
Expand Down Expand Up @@ -605,6 +605,62 @@ func.func @selection(%cond: i1) -> () {
}
```

Similarly, for the give function with a `switch` statement

```c++
void switch(int selector) {
int x = 0;
switch (selector) {
case 0:
x = 2;
break;
case 1:
x = 3;
break;
default:
x = 1;
break;
}
// ...
}
```

It will be represented as

```mlir
func.func @selection(%selector: i32) -> () {
%zero = spirv.Constant 0: i32
%one = spirv.Constant 1: i32
%two = spirv.Constant 2: i32
%three = spirv.Constant 3: i32
%var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>

spirv.mlir.selection {
spirv.Switch %selector : i32, [
default: ^default,
0: ^case0,
1: ^case1
]
^default:
spirv.Store "Function" %var, %one : i32
spirv.Branch ^merge

^case0:
spirv.Store "Function" %var, %two : i32
spirv.Branch ^merge

^case1:
spirv.Store "Function" %var, %three : i32
spirv.Branch ^merge

^merge:
spirv.mlir.merge
}

// ...
}
```

The selection can return values by yielding them with `spirv.mlir.merge`. This
mechanism allows values defined within the selection region to be used outside of it.
Without this, values that were sunk into the selection region, but used outside, would
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg
def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>;
def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>;
def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
Expand Down Expand Up @@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
Expand Down
106 changes: 106 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
}];
}

// -----

def SPIRV_SwitchOp : SPIRV_Op<"Switch",
[AttrSizedOperandSegments, InFunctionScope,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator]> {
let summary = [{
Multi-way branch to one of the operand label <id>.
}];

let description = [{
Selector must have a type of OpTypeInt. Selector is compared for equality to
the Target literals.

Default must be the <id> of a label. If Selector does not equal any of the
Target literals, control flow branches to the Default label <id>.

Target must be alternating scalar integer literals and the <id> of a label.
If Selector equals a literal, control flow branches to the following label
<id>. It is invalid for any two literal to be equal to each other. If Selector
does not equal any literal, control flow branches to the Default label <id>.
Each literal is interpreted with the type of Selector: The bit width of
Selector’s type is the width of each literal’s type. If this width is not a
multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values
are interpreted as being sign extended.

If Selector is an OpUndef, behavior is undefined.

This instruction must be the last instruction in a block.

#### Example:

```mlir
spirv.Switch %selector : si32, [
default: ^bb1(%a : i32),
0: ^bb1(%b : i32),
1: ^bb3(%c : i32)
]
```
}];

let arguments = (ins
SPIRV_Integer:$selector,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands,
OptionalAttr<AnyIntElementsAttr>:$literals,
DenseI32ArrayAttr:$case_operand_segments
);

let results = (outs);

let successors = (successor AnySuccessor:$defaultTarget,
VariadicSuccessor<AnySuccessor>:$targets);

let builders = [
OpBuilder<(ins "Value":$selector,
"Block *":$defaultTarget,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$literals,
CArg<"BlockRange", "{}">:$targets,
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
OpBuilder<(ins "Value":$selector,
"Block *":$defaultTarget,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$literals,
CArg<"BlockRange", "{}">:$targets,
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
OpBuilder<(ins "Value":$selector,
"Block *":$defaultTarget,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$literals,
CArg<"BlockRange", "{}">:$targets,
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>
];

let assemblyFormat = [{
$selector `:` type($selector) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($selector)),$defaultTarget,
$defaultOperands,
type($defaultOperands),
$literals,
$targets,
$targetOperands,
type($targetOperands))
`]`
attr-dict
}];

let extraClassDeclaration = [{
/// Return the operands for the target block at the given index.
OperandRange getTargetOperands(unsigned index) {
return getTargetOperands()[index];
}

/// Return a mutable range of operands for the target block at the
/// given index.
MutableOperandRange getTargetOperandsMutable(unsigned index) {
return getTargetOperandsMutable()[index];
}
}];

let autogenSerialization = 0;
let hasVerifier = 1;
}


// -----

def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> {
Expand Down
83 changes: 83 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
return getArgumentsMutable();
}

//===----------------------------------------------------------------------===//
// spirv.Switch
//===----------------------------------------------------------------------===//

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
Block *defaultTarget, ValueRange defaultOperands,
DenseIntElementsAttr literals, BlockRange targets,
ArrayRef<ValueRange> targetOperands) {
build(builder, result, selector, defaultOperands, targetOperands, literals,
defaultTarget, targets);
}

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
Block *defaultTarget, ValueRange defaultOperands,
ArrayRef<APInt> literals, BlockRange targets,
ArrayRef<ValueRange> targetOperands) {
DenseIntElementsAttr literalsAttr;
if (!literals.empty()) {
ShapedType literalType = VectorType::get(
static_cast<int64_t>(literals.size()), selector.getType());
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
}
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
targets, targetOperands);
}

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
Block *defaultTarget, ValueRange defaultOperands,
ArrayRef<int32_t> literals, BlockRange targets,
ArrayRef<ValueRange> targetOperands) {
DenseIntElementsAttr literalsAttr;
if (!literals.empty()) {
ShapedType literalType = VectorType::get(
static_cast<int64_t>(literals.size()), selector.getType());
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
}
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
targets, targetOperands);
}

LogicalResult SwitchOp::verify() {
std::optional<DenseIntElementsAttr> literals = getLiterals();
BlockRange targets = getTargets();

if (!literals && targets.empty())
return success();

Type selectorType = getSelector().getType();
Type literalType = literals->getType().getElementType();
if (literalType != selectorType)
return emitOpError() << "'selector' type (" << selectorType
<< ") should match literals type (" << literalType
<< ")";

if (literals && literals->size() != static_cast<int64_t>(targets.size()))
return emitOpError() << "number of literals (" << literals->size()
<< ") should match number of targets ("
<< targets.size() << ")";
return success();
}

SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
: getTargetOperandsMutable(index - 1));
}

Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
std::optional<DenseIntElementsAttr> literals = getLiterals();

if (!literals)
return getDefaultTarget();

SuccessorRange targets = getTargets();
if (auto value = dyn_cast_or_null<IntegerAttr>(operands.front())) {
for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
if (literal == value.getValue())
return targets[index];
return getDefaultTarget();
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//
Expand Down
77 changes: 77 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
}
}

/// Adapted from the cf.switch implementation.
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult parseSwitchOpCases(
OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
SmallVectorImpl<Block *> &targets,
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
&targetOperands,
SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
if (parser.parseKeyword("default") || parser.parseColon() ||
parser.parseSuccessor(defaultTarget))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false) ||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
return failure();
}

SmallVector<APInt> values;
unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
while (succeeded(parser.parseOptionalComma())) {
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));

Block *target;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<Type> operandTypes;
if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseOperandList(operands,
OpAsmParser::Delimiter::None)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
targets.push_back(target);
targetOperands.emplace_back(operands);
targetOperandTypes.emplace_back(operandTypes);
}

if (!values.empty()) {
ShapedType literalType =
VectorType::get(static_cast<int64_t>(values.size()), selectorType);
literals = DenseIntElementsAttr::get(literalType, values);
}
return success();
}

static void
printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
Block *defaultTarget, OperandRange defaultOperands,
TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
SuccessorRange targets, OperandRangeRange targetOperands,
const TypeRangeRange &targetOperandTypes) {
p << " default: ";
p.printSuccessorAndUseList(defaultTarget, defaultOperands);

if (!literals)
return;

for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
p << ',';
p.printNewline();
p << " ";
p << literal.getLimitedValue();
p << ": ";
p.printSuccessorAndUseList(targets[index], targetOperands[index]);
}
p.printNewline();
}

} // namespace mlir::spirv

// TablenGen'erated operation definitions.
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
return processLoopMerge(operands);
case spirv::Opcode::OpPhi:
return processPhi(operands);
case spirv::Opcode::OpSwitch:
return processSwitch(operands);
case spirv::Opcode::OpUndef:
return processUndef(operands);
default:
Expand Down
Loading