@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
669669 return success ();
670670}
671671
672+ LogicalResult
673+ spirv::Deserializer::processGraphEntryPointARM (ArrayRef<uint32_t > operands) {
674+ if (operands.size () < 2 ) {
675+ return emitError (unknownLoc,
676+ " missing graph defintion in OpGraphEntryPointARM" );
677+ }
678+
679+ unsigned wordIndex = 0 ;
680+ uint32_t graphID = operands[wordIndex++];
681+ if (!graphMap.contains (graphID)) {
682+ return emitError (unknownLoc,
683+ " missing graph definition/declaration with id " )
684+ << graphID;
685+ }
686+
687+ spirv::GraphARMOp graphARM = graphMap[graphID];
688+ StringRef name = decodeStringLiteral (operands, wordIndex);
689+ graphARM.setSymName (name);
690+ graphARM.setEntryPoint (true );
691+
692+ SmallVector<Attribute, 4 > interface;
693+ for (int64_t size = operands.size (); wordIndex < size; ++wordIndex) {
694+ if (spirv::GlobalVariableOp arg = getGlobalVariable (operands[wordIndex])) {
695+ interface.push_back (SymbolRefAttr::get (arg.getOperation ()));
696+ } else {
697+ return emitError (unknownLoc, " undefined result <id> " )
698+ << operands[wordIndex] << " while decoding OpGraphEntryPoint" ;
699+ }
700+ }
701+
702+ // RAII guard to reset the insertion point to previous value when done.
703+ OpBuilder::InsertionGuard insertionGuard (opBuilder);
704+ opBuilder.setInsertionPoint (graphARM);
705+ opBuilder.create <spirv::GraphEntryPointARMOp>(
706+ unknownLoc, SymbolRefAttr::get (opBuilder.getContext (), name),
707+ opBuilder.getArrayAttr (interface));
708+
709+ return success ();
710+ }
711+
712+ LogicalResult
713+ spirv::Deserializer::processGraphARM (ArrayRef<uint32_t > operands) {
714+ if (curGraph) {
715+ return emitError (unknownLoc, " found graph inside graph" );
716+ }
717+ // Get the result type.
718+ if (operands.size () < 2 ) {
719+ return emitError (unknownLoc, " OpGraphARM must have at least 2 parameters" );
720+ }
721+
722+ Type type = getType (operands[0 ]);
723+ if (!type || !isa<GraphType>(type)) {
724+ return emitError (unknownLoc, " unknown graph type from <id> " )
725+ << operands[0 ];
726+ }
727+ auto graphType = cast<GraphType>(type);
728+ if (graphType.getNumResults () <= 0 ) {
729+ return emitError (unknownLoc, " expected at least one result" );
730+ }
731+
732+ uint32_t graphID = operands[1 ];
733+ if (graphMap.count (graphID)) {
734+ return emitError (unknownLoc, " duplicate graph definition/declaration" );
735+ }
736+
737+ std::string graphName = getGraphSymbol (graphID);
738+ auto graphOp =
739+ opBuilder.create <spirv::GraphARMOp>(unknownLoc, graphName, graphType);
740+ curGraph = graphMap[graphID] = graphOp;
741+ Block *entryBlock = graphOp.addEntryBlock ();
742+ LLVM_DEBUG ({
743+ logger.startLine ()
744+ << " //===-------------------------------------------===//\n " ;
745+ logger.startLine () << " [graph] name: " << graphName << " \n " ;
746+ logger.startLine () << " [graph] type: " << graphType << " \n " ;
747+ logger.startLine () << " [graph] ID: " << graphID << " \n " ;
748+ logger.startLine () << " [graph] entry block: " << entryBlock << " \n " ;
749+ logger.indent ();
750+ });
751+
752+ // Parse the op argument instructions.
753+ for (auto [index, argType] : llvm::enumerate (graphType.getInputs ())) {
754+ spirv::Opcode opcode;
755+ ArrayRef<uint32_t > operands;
756+ if (failed (sliceInstruction (opcode, operands,
757+ spirv::Opcode::OpGraphInputARM))) {
758+ return failure ();
759+ }
760+ if (operands.size () != 3 ) {
761+ return emitError (unknownLoc, " expected result type, result <id> and "
762+ " input index for OpGraphInputARM" );
763+ }
764+
765+ Type argDefinedType = getType (operands[0 ]);
766+ if (!argDefinedType) {
767+ return emitError (unknownLoc, " unknown operand type <id> " ) << operands[0 ];
768+ }
769+
770+ if (argDefinedType != argType) {
771+ return emitError (unknownLoc,
772+ " mismatch in argument type between graph type "
773+ " definition " )
774+ << graphType << " and argument type definition " << argDefinedType
775+ << " at argument " << index;
776+ }
777+ if (getValue (operands[1 ])) {
778+ return emitError (unknownLoc, " duplicate definition of result <id> " )
779+ << operands[1 ];
780+ }
781+
782+ IntegerAttr inputIndexAttr = getConstantInt (operands[2 ]);
783+ if (!inputIndexAttr) {
784+ return emitError (unknownLoc,
785+ " unable to read inputIndex value from constant op " )
786+ << operands[2 ];
787+ }
788+ BlockArgument argValue = graphOp.getArgument (inputIndexAttr.getInt ());
789+ valueMap[operands[1 ]] = argValue;
790+ }
791+
792+ graphOutputs.resize (graphType.getNumResults ());
793+
794+ // RAII guard to reset the insertion point to the module's region after
795+ // deserializing the body of this function.
796+ OpBuilder::InsertionGuard moduleInsertionGuard (opBuilder);
797+
798+ blockMap[graphID] = entryBlock;
799+ if (failed (createGraphBlock (graphID))) {
800+ return failure ();
801+ }
802+
803+ // Process all the instructions in the graph until and including
804+ // OpGraphEndARM.
805+ spirv::Opcode opcode;
806+ ArrayRef<uint32_t > instOperands;
807+ do {
808+ if (failed (sliceInstruction (opcode, instOperands, std::nullopt ))) {
809+ return failure ();
810+ }
811+
812+ if (failed (processInstruction (opcode, instOperands))) {
813+ return failure ();
814+ }
815+ } while (opcode != spirv::Opcode::OpGraphEndARM);
816+
817+ return success ();
818+ }
819+
820+ LogicalResult
821+ spirv::Deserializer::processOpGraphSetOutputARM (ArrayRef<uint32_t > operands) {
822+ if (operands.size () != 2 ) {
823+ return emitError (
824+ unknownLoc,
825+ " expected value id and output index for OpGraphSetOutputARM" );
826+ }
827+
828+ uint32_t id = operands[0 ];
829+ Value value = getValue (id);
830+ if (!value) {
831+ return emitError (unknownLoc, " could not find result <id> " ) << id;
832+ }
833+
834+ IntegerAttr outputIndexAttr = getConstantInt (operands[1 ]);
835+ if (!outputIndexAttr) {
836+ return emitError (unknownLoc,
837+ " unable to read outputIndex value from constant op " )
838+ << operands[1 ];
839+ }
840+ graphOutputs[outputIndexAttr.getInt ()] = value;
841+ return success ();
842+ }
843+
844+ LogicalResult
845+ spirv::Deserializer::processGraphEndARM (ArrayRef<uint32_t > operands) {
846+ // Create GraphOutputsARM instruction.
847+ opBuilder.create <spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
848+
849+ // Process OpGraphEndARM.
850+ if (!operands.empty ()) {
851+ return emitError (unknownLoc, " unexpected operands for OpGraphEndARM" );
852+ }
853+
854+ curBlock = nullptr ;
855+ curGraph = std::nullopt ;
856+ graphOutputs.clear ();
857+
858+ LLVM_DEBUG ({
859+ logger.unindent ();
860+ logger.startLine ()
861+ << " //===-------------------------------------------===//\n " ;
862+ });
863+ return success ();
864+ }
865+
672866std::optional<std::pair<Attribute, Type>>
673867spirv::Deserializer::getConstant (uint32_t id) {
674868 auto constIt = constantMap.find (id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
701895 return funcName;
702896}
703897
898+ std::string spirv::Deserializer::getGraphSymbol (uint32_t id) {
899+ std::string graphName = nameMap.lookup (id).str ();
900+ if (graphName.empty ()) {
901+ graphName = " spirv_graph_" + std::to_string (id);
902+ }
903+ return graphName;
904+ }
905+
704906std::string spirv::Deserializer::getSpecConstantSymbol (uint32_t id) {
705907 auto constName = nameMap.lookup (id).str ();
706908 if (constName.empty ()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
723925 return op;
724926}
725927
928+ std::optional<spirv::GraphConstantARMOpMaterializationInfo>
929+ spirv::Deserializer::getGraphConstantARM (uint32_t id) {
930+ auto graphConstIt = graphConstantMap.find (id);
931+ if (graphConstIt == graphConstantMap.end ())
932+ return std::nullopt ;
933+ return graphConstIt->getSecond ();
934+ }
935+
726936LogicalResult
727937spirv::Deserializer::processGlobalVariable (ArrayRef<uint32_t > operands) {
728938 unsigned wordIndex = 0 ;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
9441154 return processMatrixType (operands);
9451155 case spirv::Opcode::OpTypeTensorARM:
9461156 return processTensorARMType (operands);
1157+ case spirv::Opcode::OpTypeGraphARM:
1158+ return processGraphTypeARM (operands);
9471159 default :
9481160 return emitError (unknownLoc, " unhandled type instruction" );
9491161 }
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
13111523 return success ();
13121524}
13131525
1526+ LogicalResult
1527+ spirv::Deserializer::processGraphTypeARM (ArrayRef<uint32_t > operands) {
1528+ unsigned size = operands.size ();
1529+ if (size < 2 ) {
1530+ return emitError (unknownLoc, " OpTypeGraphARM must have at least 2 operands "
1531+ " (result_id, num_inputs, (inout0_type, "
1532+ " inout1_type, ...))" )
1533+ << size;
1534+ }
1535+ uint32_t numInputs = operands[1 ];
1536+ SmallVector<Type, 1 > argTypes;
1537+ SmallVector<Type, 1 > returnTypes;
1538+ for (unsigned i = 2 ; i < size; ++i) {
1539+ Type inOutTy = getType (operands[i]);
1540+ if (!inOutTy) {
1541+ return emitError (unknownLoc,
1542+ " OpTypeGraphARM references undefined element type." )
1543+ << operands[i];
1544+ }
1545+ if (i - 2 >= numInputs) {
1546+ returnTypes.push_back (inOutTy);
1547+ } else {
1548+ argTypes.push_back (inOutTy);
1549+ }
1550+ }
1551+ typeMap[operands[0 ]] = GraphType::get (context, argTypes, returnTypes);
1552+ return success ();
1553+ }
1554+
13141555LogicalResult
13151556spirv::Deserializer::processTypeForwardPointer (ArrayRef<uint32_t > operands) {
13161557 if (operands.size () != 2 )
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
18232064 << resultType;
18242065}
18252066
2067+ LogicalResult
2068+ spirv::Deserializer::processGraphConstantARM (ArrayRef<uint32_t > operands) {
2069+ if (operands.size () < 3 ) {
2070+ return emitError (unknownLoc)
2071+ << " OpGraphConstantARM must have at least 2 operands" ;
2072+ }
2073+
2074+ Type resultType = getType (operands[0 ]);
2075+ if (!resultType) {
2076+ return emitError (unknownLoc, " undefined result type from <id> " )
2077+ << operands[0 ];
2078+ }
2079+
2080+ uint32_t resultID = operands[1 ];
2081+
2082+ if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2083+ return emitError (unknownLoc, " result must be of type OpTypeTensorARM" );
2084+ }
2085+
2086+ APInt graph_constant_id = APInt (32 , operands[2 ], /* isSigned=*/ true );
2087+ Type i32Ty = opBuilder.getIntegerType (32 );
2088+ IntegerAttr attr = opBuilder.getIntegerAttr (i32Ty, graph_constant_id);
2089+ graphConstantMap.try_emplace (
2090+ resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2091+
2092+ return success ();
2093+ }
2094+
18262095// ===----------------------------------------------------------------------===//
18272096// Control flow
18282097// ===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
19202189 return success ();
19212190}
19222191
2192+ LogicalResult spirv::Deserializer::createGraphBlock (uint32_t graphID) {
2193+ if (!curGraph) {
2194+ return emitError (unknownLoc, " a graph block must appear inside a graph" );
2195+ }
2196+
2197+ // We may have forward declared this block.
2198+ Block *block = getOrCreateBlock (graphID);
2199+ LLVM_DEBUG (logger.startLine ()
2200+ << " [block] populating block " << block << " \n " );
2201+ // If we have seen this block, make sure it was just a forward declaration.
2202+ assert (block->empty () && " re-deserialize the same block!" );
2203+
2204+ opBuilder.setInsertionPointToStart (block);
2205+ blockMap[graphID] = curBlock = block;
2206+
2207+ return success ();
2208+ }
2209+
19232210LogicalResult
19242211spirv::Deserializer::processSelectionMerge (ArrayRef<uint32_t > operands) {
19252212 if (!curBlock) {
0 commit comments