Skip to content

Commit 1a746b6

Browse files
[mlir][spirv] Add support for SPV_ARM_graph extension - part 2 (#156665)
This is the second patch to add support for the `SPV_ARM_graph` SPIR-V extension to MLIR’s SPIR-V dialect. The extension introduces a new `Graph` abstraction for expressing dataflow computations over full resources. The part 2 implementation includes: - Serialization and deserialization support for: - `OpGraphARM`, `OpGraphInputARM`, `OpGraphSetOutputARM`, `OpGraphEndARM` - `OpGraphEntryPointARM`, `OpGraphConstantARM`, `OpTypeGraphARM` - Tests covering binary round-tripping. Graphs currently support only `SPV_ARM_tensors`, but are designed to generalize to other resource types, such as images. Spec: KhronosGroup/SPIRV-Registry#346 RFC: https://discourse.llvm.org/t/rfc-add-support-for-spv-arm-graph-extension-in-mlir-spir-v-dialect/86947 --------- Signed-off-by: Davide Grohmann <[email protected]>
1 parent b39da34 commit 1a746b6

File tree

7 files changed

+619
-7
lines changed

7 files changed

+619
-7
lines changed

mlir/lib/Target/SPIRV/Deserialization/DeserializeOps.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ Value spirv::Deserializer::getValue(uint32_t id) {
8686
if (auto undef = getUndefType(id)) {
8787
return spirv::UndefOp::create(opBuilder, unknownLoc, undef);
8888
}
89+
if (std::optional<spirv::GraphConstantARMOpMaterializationInfo>
90+
graphConstantARMInfo = getGraphConstantARM(id)) {
91+
IntegerAttr graphConstantID = graphConstantARMInfo->graphConstantID;
92+
Type resultType = graphConstantARMInfo->resultType;
93+
return spirv::GraphConstantARMOp::create(opBuilder, unknownLoc, resultType,
94+
graphConstantID);
95+
}
8996
return valueMap.lookup(id);
9097
}
9198

@@ -180,6 +187,7 @@ LogicalResult spirv::Deserializer::processInstruction(
180187
case spirv::Opcode::OpTypeStruct:
181188
case spirv::Opcode::OpTypePointer:
182189
case spirv::Opcode::OpTypeTensorARM:
190+
case spirv::Opcode::OpTypeGraphARM:
183191
case spirv::Opcode::OpTypeCooperativeMatrixKHR:
184192
return processType(opcode, operands);
185193
case spirv::Opcode::OpTypeForwardPointer:
@@ -208,12 +216,26 @@ LogicalResult spirv::Deserializer::processInstruction(
208216
return processConstantBool(/*isTrue=*/false, operands, /*isSpec=*/true);
209217
case spirv::Opcode::OpConstantNull:
210218
return processConstantNull(operands);
219+
case spirv::Opcode::OpGraphConstantARM:
220+
return processGraphConstantARM(operands);
211221
case spirv::Opcode::OpDecorate:
212222
return processDecoration(operands);
213223
case spirv::Opcode::OpMemberDecorate:
214224
return processMemberDecoration(operands);
215225
case spirv::Opcode::OpFunction:
216226
return processFunction(operands);
227+
case spirv::Opcode::OpGraphEntryPointARM:
228+
if (deferInstructions) {
229+
deferredInstructions.emplace_back(opcode, operands);
230+
return success();
231+
}
232+
return processGraphEntryPointARM(operands);
233+
case spirv::Opcode::OpGraphARM:
234+
return processGraphARM(operands);
235+
case spirv::Opcode::OpGraphSetOutputARM:
236+
return processOpGraphSetOutputARM(operands);
237+
case spirv::Opcode::OpGraphEndARM:
238+
return processGraphEndARM(operands);
217239
case spirv::Opcode::OpLabel:
218240
return processLabel(operands);
219241
case spirv::Opcode::OpBranch:

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
669669
return success();
670670
}
671671

672+
LogicalResult
673+
spirv::Deserializer::processGraphEntryPointARM(ArrayRef<uint32_t> operands) {
674+
if (operands.size() < 2) {
675+
return emitError(unknownLoc,
676+
"missing graph defintion in OpGraphEntryPointARM");
677+
}
678+
679+
unsigned wordIndex = 0;
680+
uint32_t graphID = operands[wordIndex++];
681+
if (!graphMap.contains(graphID)) {
682+
return emitError(unknownLoc,
683+
"missing graph definition/declaration with id ")
684+
<< graphID;
685+
}
686+
687+
spirv::GraphARMOp graphARM = graphMap[graphID];
688+
StringRef name = decodeStringLiteral(operands, wordIndex);
689+
graphARM.setSymName(name);
690+
graphARM.setEntryPoint(true);
691+
692+
SmallVector<Attribute, 4> interface;
693+
for (int64_t size = operands.size(); wordIndex < size; ++wordIndex) {
694+
if (spirv::GlobalVariableOp arg = getGlobalVariable(operands[wordIndex])) {
695+
interface.push_back(SymbolRefAttr::get(arg.getOperation()));
696+
} else {
697+
return emitError(unknownLoc, "undefined result <id> ")
698+
<< operands[wordIndex] << " while decoding OpGraphEntryPoint";
699+
}
700+
}
701+
702+
// RAII guard to reset the insertion point to previous value when done.
703+
OpBuilder::InsertionGuard insertionGuard(opBuilder);
704+
opBuilder.setInsertionPoint(graphARM);
705+
opBuilder.create<spirv::GraphEntryPointARMOp>(
706+
unknownLoc, SymbolRefAttr::get(opBuilder.getContext(), name),
707+
opBuilder.getArrayAttr(interface));
708+
709+
return success();
710+
}
711+
712+
LogicalResult
713+
spirv::Deserializer::processGraphARM(ArrayRef<uint32_t> operands) {
714+
if (curGraph) {
715+
return emitError(unknownLoc, "found graph inside graph");
716+
}
717+
// Get the result type.
718+
if (operands.size() < 2) {
719+
return emitError(unknownLoc, "OpGraphARM must have at least 2 parameters");
720+
}
721+
722+
Type type = getType(operands[0]);
723+
if (!type || !isa<GraphType>(type)) {
724+
return emitError(unknownLoc, "unknown graph type from <id> ")
725+
<< operands[0];
726+
}
727+
auto graphType = cast<GraphType>(type);
728+
if (graphType.getNumResults() <= 0) {
729+
return emitError(unknownLoc, "expected at least one result");
730+
}
731+
732+
uint32_t graphID = operands[1];
733+
if (graphMap.count(graphID)) {
734+
return emitError(unknownLoc, "duplicate graph definition/declaration");
735+
}
736+
737+
std::string graphName = getGraphSymbol(graphID);
738+
auto graphOp =
739+
opBuilder.create<spirv::GraphARMOp>(unknownLoc, graphName, graphType);
740+
curGraph = graphMap[graphID] = graphOp;
741+
Block *entryBlock = graphOp.addEntryBlock();
742+
LLVM_DEBUG({
743+
logger.startLine()
744+
<< "//===-------------------------------------------===//\n";
745+
logger.startLine() << "[graph] name: " << graphName << "\n";
746+
logger.startLine() << "[graph] type: " << graphType << "\n";
747+
logger.startLine() << "[graph] ID: " << graphID << "\n";
748+
logger.startLine() << "[graph] entry block: " << entryBlock << "\n";
749+
logger.indent();
750+
});
751+
752+
// Parse the op argument instructions.
753+
for (auto [index, argType] : llvm::enumerate(graphType.getInputs())) {
754+
spirv::Opcode opcode;
755+
ArrayRef<uint32_t> operands;
756+
if (failed(sliceInstruction(opcode, operands,
757+
spirv::Opcode::OpGraphInputARM))) {
758+
return failure();
759+
}
760+
if (operands.size() != 3) {
761+
return emitError(unknownLoc, "expected result type, result <id> and "
762+
"input index for OpGraphInputARM");
763+
}
764+
765+
Type argDefinedType = getType(operands[0]);
766+
if (!argDefinedType) {
767+
return emitError(unknownLoc, "unknown operand type <id> ") << operands[0];
768+
}
769+
770+
if (argDefinedType != argType) {
771+
return emitError(unknownLoc,
772+
"mismatch in argument type between graph type "
773+
"definition ")
774+
<< graphType << " and argument type definition " << argDefinedType
775+
<< " at argument " << index;
776+
}
777+
if (getValue(operands[1])) {
778+
return emitError(unknownLoc, "duplicate definition of result <id> ")
779+
<< operands[1];
780+
}
781+
782+
IntegerAttr inputIndexAttr = getConstantInt(operands[2]);
783+
if (!inputIndexAttr) {
784+
return emitError(unknownLoc,
785+
"unable to read inputIndex value from constant op ")
786+
<< operands[2];
787+
}
788+
BlockArgument argValue = graphOp.getArgument(inputIndexAttr.getInt());
789+
valueMap[operands[1]] = argValue;
790+
}
791+
792+
graphOutputs.resize(graphType.getNumResults());
793+
794+
// RAII guard to reset the insertion point to the module's region after
795+
// deserializing the body of this function.
796+
OpBuilder::InsertionGuard moduleInsertionGuard(opBuilder);
797+
798+
blockMap[graphID] = entryBlock;
799+
if (failed(createGraphBlock(graphID))) {
800+
return failure();
801+
}
802+
803+
// Process all the instructions in the graph until and including
804+
// OpGraphEndARM.
805+
spirv::Opcode opcode;
806+
ArrayRef<uint32_t> instOperands;
807+
do {
808+
if (failed(sliceInstruction(opcode, instOperands, std::nullopt))) {
809+
return failure();
810+
}
811+
812+
if (failed(processInstruction(opcode, instOperands))) {
813+
return failure();
814+
}
815+
} while (opcode != spirv::Opcode::OpGraphEndARM);
816+
817+
return success();
818+
}
819+
820+
LogicalResult
821+
spirv::Deserializer::processOpGraphSetOutputARM(ArrayRef<uint32_t> operands) {
822+
if (operands.size() != 2) {
823+
return emitError(
824+
unknownLoc,
825+
"expected value id and output index for OpGraphSetOutputARM");
826+
}
827+
828+
uint32_t id = operands[0];
829+
Value value = getValue(id);
830+
if (!value) {
831+
return emitError(unknownLoc, "could not find result <id> ") << id;
832+
}
833+
834+
IntegerAttr outputIndexAttr = getConstantInt(operands[1]);
835+
if (!outputIndexAttr) {
836+
return emitError(unknownLoc,
837+
"unable to read outputIndex value from constant op ")
838+
<< operands[1];
839+
}
840+
graphOutputs[outputIndexAttr.getInt()] = value;
841+
return success();
842+
}
843+
844+
LogicalResult
845+
spirv::Deserializer::processGraphEndARM(ArrayRef<uint32_t> operands) {
846+
// Create GraphOutputsARM instruction.
847+
opBuilder.create<spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
848+
849+
// Process OpGraphEndARM.
850+
if (!operands.empty()) {
851+
return emitError(unknownLoc, "unexpected operands for OpGraphEndARM");
852+
}
853+
854+
curBlock = nullptr;
855+
curGraph = std::nullopt;
856+
graphOutputs.clear();
857+
858+
LLVM_DEBUG({
859+
logger.unindent();
860+
logger.startLine()
861+
<< "//===-------------------------------------------===//\n";
862+
});
863+
return success();
864+
}
865+
672866
std::optional<std::pair<Attribute, Type>>
673867
spirv::Deserializer::getConstant(uint32_t id) {
674868
auto constIt = constantMap.find(id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
701895
return funcName;
702896
}
703897

898+
std::string spirv::Deserializer::getGraphSymbol(uint32_t id) {
899+
std::string graphName = nameMap.lookup(id).str();
900+
if (graphName.empty()) {
901+
graphName = "spirv_graph_" + std::to_string(id);
902+
}
903+
return graphName;
904+
}
905+
704906
std::string spirv::Deserializer::getSpecConstantSymbol(uint32_t id) {
705907
auto constName = nameMap.lookup(id).str();
706908
if (constName.empty()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
723925
return op;
724926
}
725927

928+
std::optional<spirv::GraphConstantARMOpMaterializationInfo>
929+
spirv::Deserializer::getGraphConstantARM(uint32_t id) {
930+
auto graphConstIt = graphConstantMap.find(id);
931+
if (graphConstIt == graphConstantMap.end())
932+
return std::nullopt;
933+
return graphConstIt->getSecond();
934+
}
935+
726936
LogicalResult
727937
spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
728938
unsigned wordIndex = 0;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
9441154
return processMatrixType(operands);
9451155
case spirv::Opcode::OpTypeTensorARM:
9461156
return processTensorARMType(operands);
1157+
case spirv::Opcode::OpTypeGraphARM:
1158+
return processGraphTypeARM(operands);
9471159
default:
9481160
return emitError(unknownLoc, "unhandled type instruction");
9491161
}
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
13111523
return success();
13121524
}
13131525

1526+
LogicalResult
1527+
spirv::Deserializer::processGraphTypeARM(ArrayRef<uint32_t> operands) {
1528+
unsigned size = operands.size();
1529+
if (size < 2) {
1530+
return emitError(unknownLoc, "OpTypeGraphARM must have at least 2 operands "
1531+
"(result_id, num_inputs, (inout0_type, "
1532+
"inout1_type, ...))")
1533+
<< size;
1534+
}
1535+
uint32_t numInputs = operands[1];
1536+
SmallVector<Type, 1> argTypes;
1537+
SmallVector<Type, 1> returnTypes;
1538+
for (unsigned i = 2; i < size; ++i) {
1539+
Type inOutTy = getType(operands[i]);
1540+
if (!inOutTy) {
1541+
return emitError(unknownLoc,
1542+
"OpTypeGraphARM references undefined element type.")
1543+
<< operands[i];
1544+
}
1545+
if (i - 2 >= numInputs) {
1546+
returnTypes.push_back(inOutTy);
1547+
} else {
1548+
argTypes.push_back(inOutTy);
1549+
}
1550+
}
1551+
typeMap[operands[0]] = GraphType::get(context, argTypes, returnTypes);
1552+
return success();
1553+
}
1554+
13141555
LogicalResult
13151556
spirv::Deserializer::processTypeForwardPointer(ArrayRef<uint32_t> operands) {
13161557
if (operands.size() != 2)
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
18232064
<< resultType;
18242065
}
18252066

2067+
LogicalResult
2068+
spirv::Deserializer::processGraphConstantARM(ArrayRef<uint32_t> operands) {
2069+
if (operands.size() < 3) {
2070+
return emitError(unknownLoc)
2071+
<< "OpGraphConstantARM must have at least 2 operands";
2072+
}
2073+
2074+
Type resultType = getType(operands[0]);
2075+
if (!resultType) {
2076+
return emitError(unknownLoc, "undefined result type from <id> ")
2077+
<< operands[0];
2078+
}
2079+
2080+
uint32_t resultID = operands[1];
2081+
2082+
if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2083+
return emitError(unknownLoc, "result must be of type OpTypeTensorARM");
2084+
}
2085+
2086+
APInt graph_constant_id = APInt(32, operands[2], /*isSigned=*/true);
2087+
Type i32Ty = opBuilder.getIntegerType(32);
2088+
IntegerAttr attr = opBuilder.getIntegerAttr(i32Ty, graph_constant_id);
2089+
graphConstantMap.try_emplace(
2090+
resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2091+
2092+
return success();
2093+
}
2094+
18262095
//===----------------------------------------------------------------------===//
18272096
// Control flow
18282097
//===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
19202189
return success();
19212190
}
19222191

2192+
LogicalResult spirv::Deserializer::createGraphBlock(uint32_t graphID) {
2193+
if (!curGraph) {
2194+
return emitError(unknownLoc, "a graph block must appear inside a graph");
2195+
}
2196+
2197+
// We may have forward declared this block.
2198+
Block *block = getOrCreateBlock(graphID);
2199+
LLVM_DEBUG(logger.startLine()
2200+
<< "[block] populating block " << block << "\n");
2201+
// If we have seen this block, make sure it was just a forward declaration.
2202+
assert(block->empty() && "re-deserialize the same block!");
2203+
2204+
opBuilder.setInsertionPointToStart(block);
2205+
blockMap[graphID] = curBlock = block;
2206+
2207+
return success();
2208+
}
2209+
19232210
LogicalResult
19242211
spirv::Deserializer::processSelectionMerge(ArrayRef<uint32_t> operands) {
19252212
if (!curBlock) {

0 commit comments

Comments
 (0)