Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -4531,6 +4531,7 @@ def SPIRV_OC_OpSelectionMerge : I32EnumAttrCase<"OpSelectionMerg
def SPIRV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
def SPIRV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
def SPIRV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
def SPIRV_OC_OpSwitch : I32EnumAttrCase<"OpSwitch", 251>;
def SPIRV_OC_OpKill : I32EnumAttrCase<"OpKill", 252>;
def SPIRV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPIRV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
Expand Down Expand Up @@ -4681,7 +4682,7 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpAtomicAnd, SPIRV_OC_OpAtomicOr, SPIRV_OC_OpAtomicXor,
SPIRV_OC_OpPhi, SPIRV_OC_OpLoopMerge, SPIRV_OC_OpSelectionMerge,
SPIRV_OC_OpLabel, SPIRV_OC_OpBranch, SPIRV_OC_OpBranchConditional,
SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
SPIRV_OC_OpSwitch, SPIRV_OC_OpKill, SPIRV_OC_OpReturn, SPIRV_OC_OpReturnValue,
SPIRV_OC_OpUnreachable, SPIRV_OC_OpGroupBroadcast, SPIRV_OC_OpGroupIAdd,
SPIRV_OC_OpGroupFAdd, SPIRV_OC_OpGroupFMin, SPIRV_OC_OpGroupUMin,
SPIRV_OC_OpGroupSMin, SPIRV_OC_OpGroupFMax, SPIRV_OC_OpGroupUMax,
Expand Down
106 changes: 106 additions & 0 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,112 @@ def SPIRV_FunctionCallOp : SPIRV_Op<"FunctionCall", [
}];
}

// -----

def SPIRV_SwitchOp : SPIRV_Op<"Switch",
[AttrSizedOperandSegments, InFunctionScope,
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
Pure, Terminator]> {
let summary = [{
Multi-way branch to one of the operand label <id>.
}];

let description = [{
Selector must have a type of OpTypeInt. Selector is compared for equality to
the Target literals.

Default must be the <id> of a label. If Selector does not equal any of the
Target literals, control flow branches to the Default label <id>.

Target must be alternating scalar integer literals and the <id> of a label.
If Selector equals a literal, control flow branches to the following label
<id>. It is invalid for any two literal to be equal to each other. If Selector
does not equal any literal, control flow branches to the Default label <id>.
Each literal is interpreted with the type of Selector: The bit width of
Selector’s type is the width of each literal’s type. If this width is not a
multiple of 32-bits and the OpTypeInt Signedness is set to 1, the literal values
are interpreted as being sign extended.

If Selector is an OpUndef, behavior is undefined.

This instruction must be the last instruction in a block.

#### Example:

```mlir
spirv.Switch %selector : si32, [
default: ^bb1(%a : i32),
0: ^bb1(%b : i32),
1: ^bb3(%c : i32)
]
```
}];

let arguments = (ins
SPIRV_Integer:$selector,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$targetOperands,
OptionalAttr<AnyIntElementsAttr>:$literals,
DenseI32ArrayAttr:$case_operand_segments
);

let results = (outs);

let successors = (successor AnySuccessor:$defaultTarget,
VariadicSuccessor<AnySuccessor>:$targets);

let builders = [
OpBuilder<(ins "Value":$selector,
"Block *":$defaultTarget,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<APInt>", "{}">:$literals,
CArg<"BlockRange", "{}">:$targets,
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
OpBuilder<(ins "Value":$selector,
"Block *":$defaultTarget,
"ValueRange":$defaultOperands,
CArg<"ArrayRef<int32_t>", "{}">:$literals,
CArg<"BlockRange", "{}">:$targets,
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>,
OpBuilder<(ins "Value":$selector,
"Block *":$defaultTarget,
"ValueRange":$defaultOperands,
CArg<"DenseIntElementsAttr", "{}">:$literals,
CArg<"BlockRange", "{}">:$targets,
CArg<"ArrayRef<ValueRange>", "{}">:$targetOperands)>
];

let assemblyFormat = [{
$selector `:` type($selector) `,` `[` `\n`
custom<SwitchOpCases>(ref(type($selector)),$defaultTarget,
$defaultOperands,
type($defaultOperands),
$literals,
$targets,
$targetOperands,
type($targetOperands))
`]`
attr-dict
}];

let extraClassDeclaration = [{
/// Return the operands for the target block at the given index.
OperandRange getTargetOperands(unsigned index) {
return getTargetOperands()[index];
}

/// Return a mutable range of operands for the target block at the
/// given index.
MutableOperandRange getTargetOperandsMutable(unsigned index) {
return getTargetOperandsMutable()[index];
}
}];

let autogenSerialization = 0;
let hasVerifier = 1;
}


// -----

def SPIRV_KillOp : SPIRV_Op<"Kill", [Terminator]> {
Expand Down
83 changes: 83 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,89 @@ MutableOperandRange FunctionCallOp::getArgOperandsMutable() {
return getArgumentsMutable();
}

//===----------------------------------------------------------------------===//
// spirv.Switch
//===----------------------------------------------------------------------===//

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
Block *defaultTarget, ValueRange defaultOperands,
DenseIntElementsAttr literals, BlockRange targets,
ArrayRef<ValueRange> targetOperands) {
build(builder, result, selector, defaultOperands, targetOperands, literals,
defaultTarget, targets);
}

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
Block *defaultTarget, ValueRange defaultOperands,
ArrayRef<APInt> literals, BlockRange targets,
ArrayRef<ValueRange> targetOperands) {
DenseIntElementsAttr literalsAttr;
if (!literals.empty()) {
ShapedType literalType = VectorType::get(
static_cast<int64_t>(literals.size()), selector.getType());
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
}
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
targets, targetOperands);
}

void SwitchOp::build(OpBuilder &builder, OperationState &result, Value selector,
Block *defaultTarget, ValueRange defaultOperands,
ArrayRef<int32_t> literals, BlockRange targets,
ArrayRef<ValueRange> targetOperands) {
DenseIntElementsAttr literalsAttr;
if (!literals.empty()) {
ShapedType literalType = VectorType::get(
static_cast<int64_t>(literals.size()), selector.getType());
literalsAttr = DenseIntElementsAttr::get(literalType, literals);
}
build(builder, result, selector, defaultTarget, defaultOperands, literalsAttr,
targets, targetOperands);
}

LogicalResult SwitchOp::verify() {
std::optional<DenseIntElementsAttr> literals = getLiterals();
BlockRange targets = getTargets();

if (!literals && targets.empty())
return success();

Type selectorType = getSelector().getType();
Type literalType = literals->getType().getElementType();
if (literalType != selectorType)
return emitOpError() << "'selector' type (" << selectorType
<< ") should match literals type (" << literalType
<< ")";

if (literals && literals->size() != static_cast<int64_t>(targets.size()))
return emitOpError() << "number of literals (" << literals->size()
<< ") should match number of targets ("
<< targets.size() << ")";
return success();
}

SuccessorOperands SwitchOp::getSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");
return SuccessorOperands(index == 0 ? getDefaultOperandsMutable()
: getTargetOperandsMutable(index - 1));
}

Block *SwitchOp::getSuccessorForOperands(ArrayRef<Attribute> operands) {
std::optional<DenseIntElementsAttr> literals = getLiterals();

if (!literals)
return getDefaultTarget();

SuccessorRange targets = getTargets();
if (auto value = llvm::dyn_cast_or_null<IntegerAttr>(operands.front())) {
for (auto [index, literal] : llvm::enumerate(literals->getValues<APInt>()))
if (literal == value.getValue())
return targets[index];
return getDefaultTarget();
}
return nullptr;
}

//===----------------------------------------------------------------------===//
// spirv.mlir.loop
//===----------------------------------------------------------------------===//
Expand Down
77 changes: 77 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOpDefinition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,83 @@ static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
}
}

/// Adapted from the cf.switch implementation.
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult parseSwitchOpCases(
OpAsmParser &parser, Type &selectorType, Block *&defaultTarget,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &defaultOperands,
SmallVectorImpl<Type> &defaultOperandTypes, DenseIntElementsAttr &literals,
SmallVectorImpl<Block *> &targets,
SmallVectorImpl<SmallVector<OpAsmParser::UnresolvedOperand>>
&targetOperands,
SmallVectorImpl<SmallVector<Type>> &targetOperandTypes) {
if (parser.parseKeyword("default") || parser.parseColon() ||
parser.parseSuccessor(defaultTarget))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (parser.parseOperandList(defaultOperands, OpAsmParser::Delimiter::None,
/*allowResultNumber=*/false) ||
parser.parseColonTypeList(defaultOperandTypes) || parser.parseRParen())
return failure();
}

SmallVector<APInt> values;
unsigned bitWidth = selectorType.getIntOrFloatBitWidth();
while (succeeded(parser.parseOptionalComma())) {
int64_t value = 0;
if (failed(parser.parseInteger(value)))
return failure();
values.push_back(APInt(bitWidth, value, /*isSigned=*/true));

Block *target;
SmallVector<OpAsmParser::UnresolvedOperand> operands;
SmallVector<Type> operandTypes;
if (failed(parser.parseColon()) || failed(parser.parseSuccessor(target)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseOperandList(operands,
OpAsmParser::Delimiter::None)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
targets.push_back(target);
targetOperands.emplace_back(operands);
targetOperandTypes.emplace_back(operandTypes);
}

if (!values.empty()) {
ShapedType literalType =
VectorType::get(static_cast<int64_t>(values.size()), selectorType);
literals = DenseIntElementsAttr::get(literalType, values);
}
return success();
}

static void
printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type selectorType,
Block *defaultTarget, OperandRange defaultOperands,
TypeRange defaultOperandTypes, DenseIntElementsAttr literals,
SuccessorRange targets, OperandRangeRange targetOperands,
const TypeRangeRange &targetOperandTypes) {
p << " default: ";
p.printSuccessorAndUseList(defaultTarget, defaultOperands);

if (!literals)
return;

for (auto [index, literal] : llvm::enumerate(literals.getValues<APInt>())) {
p << ',';
p.printNewline();
p << " ";
p << literal.getLimitedValue();
p << ": ";
p.printSuccessorAndUseList(targets[index], targetOperands[index]);
}
p.printNewline();
}

} // namespace mlir::spirv

// TablenGen'erated operation definitions.
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,8 @@ LogicalResult spirv::Deserializer::processInstruction(
return processLoopMerge(operands);
case spirv::Opcode::OpPhi:
return processPhi(operands);
case spirv::Opcode::OpSwitch:
return processSwitch(operands);
case spirv::Opcode::OpUndef:
return processUndef(operands);
default:
Expand Down
32 changes: 32 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2292,6 +2292,38 @@ LogicalResult spirv::Deserializer::processPhi(ArrayRef<uint32_t> operands) {
return success();
}

LogicalResult spirv::Deserializer::processSwitch(ArrayRef<uint32_t> operands) {
if (!curBlock)
return emitError(unknownLoc, "OpSwitch must appear in a block");

if (operands.size() < 2)
return emitError(unknownLoc, "OpSwitch must at least specify selector and "
"a default target");

if (operands.size() % 2)
return emitError(unknownLoc,
"OpSwitch must at have an even number of operands: "
"selector, default target and any number of literal and "
"label <id> pairs");

Value selector = getValue(operands[0]);
Block *defaultBlock = getOrCreateBlock(operands[1]);
Location loc = createFileLineColLoc(opBuilder);

SmallVector<int32_t> literals;
SmallVector<Block *> blocks;
for (unsigned i = 2, e = operands.size(); i < e; i += 2) {
literals.push_back(operands[i]);
blocks.push_back(getOrCreateBlock(operands[i + 1]));
}

SmallVector<ValueRange> targetOperands(blocks.size(), {});
spirv::SwitchOp::create(opBuilder, loc, selector, defaultBlock,
ArrayRef<Value>(), literals, blocks, targetOperands);

return success();
}

namespace {
/// A class for putting all blocks in a structured selection/loop in a
/// spirv.mlir.selection/spirv.mlir.loop op.
Expand Down
3 changes: 3 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,9 @@ class Deserializer {
/// Processes a SPIR-V OpPhi instruction with the given `operands`.
LogicalResult processPhi(ArrayRef<uint32_t> operands);

/// Processes a SPIR-V OpSwitch instruction with the given `operands`.
LogicalResult processSwitch(ArrayRef<uint32_t> operands);

/// Creates block arguments on predecessors previously recorded when handling
/// OpPhi instructions.
LogicalResult wireUpBlockArgument();
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -775,6 +775,27 @@ LogicalResult Serializer::processBranchOp(spirv::BranchOp branchOp) {
return success();
}

LogicalResult Serializer::processSwitchOp(spirv::SwitchOp switchOp) {
uint32_t selectorID = getValueID(switchOp.getSelector());
uint32_t defaultLabelID = getOrCreateBlockID(switchOp.getDefaultTarget());
SmallVector<uint32_t> arguments{selectorID, defaultLabelID};

std::optional<mlir::DenseIntElementsAttr> literals = switchOp.getLiterals();
BlockRange targets = switchOp.getTargets();
if (literals) {
for (auto [literal, target] : llvm::zip_equal(*literals, targets)) {
arguments.push_back(literal.getLimitedValue());
uint32_t targetLabelID = getOrCreateBlockID(target);
arguments.push_back(targetLabelID);
}
}

if (failed(emitDebugLine(functionBody, switchOp.getLoc())))
return failure();
encodeInstructionInto(functionBody, spirv::Opcode::OpSwitch, arguments);
return success();
}

LogicalResult Serializer::processAddressOfOp(spirv::AddressOfOp addressOfOp) {
auto varName = addressOfOp.getVariable();
auto variableID = getVariableID(varName);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
.Case([&](spirv::SpecConstantOperationOp op) {
return processSpecConstantOperationOp(op);
})
.Case([&](spirv::SwitchOp op) { return processSwitchOp(op); })
.Case([&](spirv::UndefOp op) { return processUndefOp(op); })
.Case([&](spirv::VariableOp op) { return processVariableOp(op); })

Expand Down
Loading