Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) {
if (auto undef = getUndefType(id)) {
return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
}
if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
graphConstantARMInfo = getGraphConstantARM(id)) {
IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
Type resultType = graphConstantARMInfo->resultType;
return opBuilder.create<spirv::GraphConstantARMOp>(unknownLoc, resultType,
graphConstantID);
}
return valueMap.lookup(id);
}

Expand Down Expand Up @@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction(
case spirv::Opcode::OpTypeStruct:
case spirv::Opcode::OpTypePointer:
case spirv::Opcode::OpTypeTensorARM:
case spirv::Opcode::OpTypeGraphARM:
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
return processType(opcode, operands);
case spirv::Opcode::OpTypeForwardPointer:
Expand Down Expand Up @@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction(
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
case spirv::Opcode::OpConstantNull:
return processConstantNull(operands);
case spirv::Opcode::OpGraphConstantARM:
return processGraphConstantARM(operands);
case spirv::Opcode::OpDecorate:
return processDecoration(operands);
case spirv::Opcode::OpMemberDecorate:
return processMemberDecoration(operands);
case spirv::Opcode::OpFunction:
return processFunction(operands);
case spirv::Opcode::OpGraphEntryPointARM:
if (deferInstructions) {
deferredInstructions.emplace_back(opcode, operands);
return success();
}
return processGraphEntryPointARM(operands);
case spirv::Opcode::OpGraphARM:
return processGraphARM(operands);
case spirv::Opcode::OpGraphSetOutputARM:
return processOpGraphSetOutputARM(operands);
case spirv::Opcode::OpGraphEndARM:
return processGraphEndARM(operands);
case spirv::Opcode::OpLabel:
return processLabel(operands);
case spirv::Opcode::OpBranch:
Expand Down
287 changes: 287 additions & 0 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
return success();
}

LogicalResult
spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
if (operands.size() < 2) {
return emitError(unknownLoc,
"missing graph defintion in OpGraphEntryPointARM");
}

unsigned wordIndex = 0;
uint32_t grID = operands[wordIndex++];
if (!graphMap.count(grID)) {
return emitError(unknownLoc,
"missing graph definition/declaration with id ")
<< grID;
}

spirv::GraphARMOp graphARM = graphMap[grID];
StringRef name = decodeStringLiteral(operands, wordIndex);
graphARM.setSymName(name);
graphARM.setEntryPoint(true);

SmallVector<Attribute, 4> interface;
for (int64_t size = operands.size(); wordIndex < size; wordIndex++) {
if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
interface.push_back(SymbolRefAttr::get(arg.getOperation()));
} else {
return emitError(unknownLoc, "undefined result <id> ")
<< operands[wordIndex] << " while decoding OpGraphEntryPoint";
}
}

// RAII guard to reset the insertion point to previous value when done.
OpBuilder::InsertionGuard insertionGuard(opBuilder);
opBuilder.setInsertionPoint(graphARM);
opBuilder.create<spirv::GraphEntryPointARMOp>(
unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
opBuilder.getArrayAttr(interface));

return success();
}

LogicalResult
spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
if (curGraph) {
return emitError(unknownLoc, "found graph inside graph");
}
// Get the result type.
if (operands.size() < 2) {
return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
}

Type type = getType(operands[0]);
if (!type || !isa<GraphType>(type)) {
return emitError(unknownLoc, "unknown graph type from <id> ")
<< operands[0];
}
auto graphType = cast<GraphType>(type);
if (graphType.getNumResults() <= 0) {
return emitError(unknownLoc, "expected at least one result");
}

uint32_t grID = operands[1];
if (graphMap.count(grID)) {
return emitError(unknownLoc, "duplicate graph definition/declaration");
}

std::string grName = getGraphSymbol(grID);
auto graphOp =
opBuilder.create<spirv::GraphARMOp>(unknownLoc, grName, graphType);
curGraph = graphMap[grID] = graphOp;
Block *entryBlock = graphOp.addEntryBlock();
LLVM_DEBUG({
logger.startLine()
<< "//===-------------------------------------------===//\n";
logger.startLine() << "[graph] name: " << grName << "\n";
logger.startLine() << "[graph] type: " << graphType << "\n";
logger.startLine() << "[graph] ID: " << grID << "\n";
logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
logger.indent();
});

// Parse the op argument instructions.
for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
spirv::Opcode opcode;
ArrayRef<uint32_t> operands;
if (failed(sliceInstruction(opcode, operands,
spirv::Opcode::OpGraphInputARM))) {
return failure();
}
if (operands.size() != 3) {
return emitError(unknownLoc, "expected result type, result <id> and "
"input index for OpGraphInputARM");
}

Type argDefinedType = getType(operands[0]);
if (!argDefinedType) {
return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
}

if (argDefinedType != argType) {
return emitError(unknownLoc,
"mismatch in argument type between graph type "
"definition ")
<< graphType << " and argument type definition " << argDefinedType
<< " at argument " << index;
}
if (getValue(operands[1])) {
return emitError(unknownLoc, "duplicate definition of result <id> ")
<< operands[1];
}

IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
if (!inputIndexAttr) {
return emitError(unknownLoc,
"unable to read inputIndex value from constant op ")
<< operands[2];
}
BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
valueMap[operands[1]] = argValue;
}

graphOutputs.resize(graphType.getNumResults());

// RAII guard to reset the insertion point to the module's region after
// deserializing the body of this function.
OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);

blockMap[grID] = entryBlock;
if (failed(createGraphBlock(grID))) {
return failure();
}

// Process all the instructions in the graph until and including
// OpGraphEndARM.
spirv::Opcode opcode;
ArrayRef<uint32_t> instOperands;
do {
if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
return failure();
}

if (failed(processInstruction(opcode, instOperands))) {
return failure();
}
} while (opcode != spirv::Opcode::OpGraphEndARM);

return success();
}

LogicalResult
spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
if (operands.size() != 2) {
return emitError(
unknownLoc,
"expected value id and output index for OpGraphSetOutputARM");
}

uint32_t id = operands[0];
Value value = getValue(id);
if (!value) {
return emitError(unknownLoc, "could not find result <id> ") << id;
}

IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
if (!outputIndexAttr) {
return emitError(unknownLoc,
"unable to read outputIndex value from constant op ")
<< operands[1];
}
graphOutputs[outputIndexAttr.getInt()] = value;
return success();
}

LogicalResult
spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
// Create GraphOutputsARM instruction.
opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);

// Process OpGraphEndARM.
if (!operands.empty()) {
return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
}

curBlock = nullptr;
curGraph = std::nullopt;
graphOutputs.clear();

LLVM_DEBUG({
logger.unindent();
logger.startLine()
<< "//===-------------------------------------------===//\n";
});
return success();
}

std::optional<std::pair<Attribute, Type>>
spirv::Deserializer::getConstant(uint32_t id) {
auto constIt = constantMap.find(id);
Expand Down Expand Up @@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
return funcName;
}

std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
std::string graphName = nameMap.lookup(id).str();
if (graphName.empty()) {
graphName = "spirv_graph_" + std::to_string(id);
}
return graphName;
}

std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
auto constName = nameMap.lookup(id).str();
if (constName.empty()) {
Expand All @@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
return op;
}

std::optional<spirv::GraphConstantARMOpMaterializationInfo>
spirv::Deserializer::getGraphConstantARM(uint32_t id) {
auto graphConstIt = graphConstantMap.find(id);
if (graphConstIt == graphConstantMap.end())
return std::nullopt;
return graphConstIt->getSecond();
}

LogicalResult
spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
unsigned wordIndex = 0;
Expand Down Expand Up @@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
return processMatrixType(operands);
case spirv::Opcode::OpTypeTensorARM:
return processTensorARMType(operands);
case spirv::Opcode::OpTypeGraphARM:
return processGraphTypeARM(operands);
default:
return emitError(unknownLoc, "unhandled type instruction");
}
Expand Down Expand Up @@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
return success();
}

LogicalResult
spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
unsigned size = operands.size();
if (size < 2) {
return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
"(result_id, num_inputs, (inout0_type, "
"inout1_type, ...))")
<< size;
}
uint32_t numInputs = operands[1];
SmallVector<Type, 1> argTypes;
SmallVector<Type, 1> returnTypes;
for (unsigned i = 2; i < size; i++) {
Type inOutTy = getType(operands[i]);
if (!inOutTy) {
return emitError(unknownLoc,
"OpTypeGraphARM references undefined element type.")
<< operands[i];
}
if (i - 2 >= numInputs) {
returnTypes.push_back(inOutTy);
} else {
argTypes.push_back(inOutTy);
}
}
typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
return success();
}

LogicalResult
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
if (operands.size() != 2)
Expand Down Expand Up @@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
<< resultType;
}

LogicalResult
spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
if (operands.size() < 3) {
return emitError(unknownLoc)
<< "OpGraphConstantARM must have at least 2 operands";
}

Type resultType = getType(operands[0]);
if (!resultType) {
return emitError(unknownLoc, "undefined result type from <id> ")
<< operands[0];
}

uint32_t resultID = operands[1];

if (!dyn_cast<spirv::TensorArmType>(resultType)) {
return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
}

APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
Type i32Ty = opBuilder.getIntegerType(32);
IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
graphConstantMap.try_emplace(
resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});

return success();
}

//===----------------------------------------------------------------------===//
// Control flow
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
return success();
}

LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
if (!curGraph) {
return emitError(unknownLoc, "a graph block must appear inside a graph");
}

// We may have forward declared this block.
Block *block = getOrCreateBlock(graphID);
LLVM_DEBUG(logger.startLine()
<< "[block] populating block " << block << "\n");
// If we have seen this block, make sure it was just a forward declaration.
assert(block->empty() && "re-deserialize the same block!");

opBuilder.setInsertionPointToStart(block);
blockMap[graphID] = curBlock = block;

return success();
}

LogicalResult
spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
if (!curBlock) {
Expand Down
Loading
Loading