@@ -669,6 +669,200 @@ spirv::Deserializer::processFunctionEnd(ArrayRef<uint32_t> operands) {
669
669
return success ();
670
670
}
671
671
672
+ LogicalResult
673
+ spirv::Deserializer::processGraphEntryPointARM (ArrayRef<uint32_t > operands) {
674
+ if (operands.size () < 2 ) {
675
+ return emitError (unknownLoc,
676
+ " missing graph defintion in OpGraphEntryPointARM" );
677
+ }
678
+
679
+ unsigned wordIndex = 0 ;
680
+ uint32_t graphID = operands[wordIndex++];
681
+ if (!graphMap.contains (graphID)) {
682
+ return emitError (unknownLoc,
683
+ " missing graph definition/declaration with id " )
684
+ << graphID;
685
+ }
686
+
687
+ spirv::GraphARMOp graphARM = graphMap[graphID];
688
+ StringRef name = decodeStringLiteral (operands, wordIndex);
689
+ graphARM.setSymName (name);
690
+ graphARM.setEntryPoint (true );
691
+
692
+ SmallVector<Attribute, 4 > interface;
693
+ for (int64_t size = operands.size (); wordIndex < size; ++wordIndex) {
694
+ if (spirv::GlobalVariableOp arg = getGlobalVariable (operands[wordIndex])) {
695
+ interface.push_back (SymbolRefAttr::get (arg.getOperation ()));
696
+ } else {
697
+ return emitError (unknownLoc, " undefined result <id> " )
698
+ << operands[wordIndex] << " while decoding OpGraphEntryPoint" ;
699
+ }
700
+ }
701
+
702
+ // RAII guard to reset the insertion point to previous value when done.
703
+ OpBuilder::InsertionGuard insertionGuard (opBuilder);
704
+ opBuilder.setInsertionPoint (graphARM);
705
+ opBuilder.create <spirv::GraphEntryPointARMOp>(
706
+ unknownLoc, SymbolRefAttr::get (opBuilder.getContext (), name),
707
+ opBuilder.getArrayAttr (interface));
708
+
709
+ return success ();
710
+ }
711
+
712
+ LogicalResult
713
+ spirv::Deserializer::processGraphARM (ArrayRef<uint32_t > operands) {
714
+ if (curGraph) {
715
+ return emitError (unknownLoc, " found graph inside graph" );
716
+ }
717
+ // Get the result type.
718
+ if (operands.size () < 2 ) {
719
+ return emitError (unknownLoc, " OpGraphARM must have at least 2 parameters" );
720
+ }
721
+
722
+ Type type = getType (operands[0 ]);
723
+ if (!type || !isa<GraphType>(type)) {
724
+ return emitError (unknownLoc, " unknown graph type from <id> " )
725
+ << operands[0 ];
726
+ }
727
+ auto graphType = cast<GraphType>(type);
728
+ if (graphType.getNumResults () <= 0 ) {
729
+ return emitError (unknownLoc, " expected at least one result" );
730
+ }
731
+
732
+ uint32_t graphID = operands[1 ];
733
+ if (graphMap.count (graphID)) {
734
+ return emitError (unknownLoc, " duplicate graph definition/declaration" );
735
+ }
736
+
737
+ std::string graphName = getGraphSymbol (graphID);
738
+ auto graphOp =
739
+ opBuilder.create <spirv::GraphARMOp>(unknownLoc, graphName, graphType);
740
+ curGraph = graphMap[graphID] = graphOp;
741
+ Block *entryBlock = graphOp.addEntryBlock ();
742
+ LLVM_DEBUG ({
743
+ logger.startLine ()
744
+ << " //===-------------------------------------------===//\n " ;
745
+ logger.startLine () << " [graph] name: " << graphName << " \n " ;
746
+ logger.startLine () << " [graph] type: " << graphType << " \n " ;
747
+ logger.startLine () << " [graph] ID: " << graphID << " \n " ;
748
+ logger.startLine () << " [graph] entry block: " << entryBlock << " \n " ;
749
+ logger.indent ();
750
+ });
751
+
752
+ // Parse the op argument instructions.
753
+ for (auto [index, argType] : llvm::enumerate (graphType.getInputs ())) {
754
+ spirv::Opcode opcode;
755
+ ArrayRef<uint32_t > operands;
756
+ if (failed (sliceInstruction (opcode, operands,
757
+ spirv::Opcode::OpGraphInputARM))) {
758
+ return failure ();
759
+ }
760
+ if (operands.size () != 3 ) {
761
+ return emitError (unknownLoc, " expected result type, result <id> and "
762
+ " input index for OpGraphInputARM" );
763
+ }
764
+
765
+ Type argDefinedType = getType (operands[0 ]);
766
+ if (!argDefinedType) {
767
+ return emitError (unknownLoc, " unknown operand type <id> " ) << operands[0 ];
768
+ }
769
+
770
+ if (argDefinedType != argType) {
771
+ return emitError (unknownLoc,
772
+ " mismatch in argument type between graph type "
773
+ " definition " )
774
+ << graphType << " and argument type definition " << argDefinedType
775
+ << " at argument " << index;
776
+ }
777
+ if (getValue (operands[1 ])) {
778
+ return emitError (unknownLoc, " duplicate definition of result <id> " )
779
+ << operands[1 ];
780
+ }
781
+
782
+ IntegerAttr inputIndexAttr = getConstantInt (operands[2 ]);
783
+ if (!inputIndexAttr) {
784
+ return emitError (unknownLoc,
785
+ " unable to read inputIndex value from constant op " )
786
+ << operands[2 ];
787
+ }
788
+ BlockArgument argValue = graphOp.getArgument (inputIndexAttr.getInt ());
789
+ valueMap[operands[1 ]] = argValue;
790
+ }
791
+
792
+ graphOutputs.resize (graphType.getNumResults ());
793
+
794
+ // RAII guard to reset the insertion point to the module's region after
795
+ // deserializing the body of this function.
796
+ OpBuilder::InsertionGuard moduleInsertionGuard (opBuilder);
797
+
798
+ blockMap[graphID] = entryBlock;
799
+ if (failed (createGraphBlock (graphID))) {
800
+ return failure ();
801
+ }
802
+
803
+ // Process all the instructions in the graph until and including
804
+ // OpGraphEndARM.
805
+ spirv::Opcode opcode;
806
+ ArrayRef<uint32_t > instOperands;
807
+ do {
808
+ if (failed (sliceInstruction (opcode, instOperands, std::nullopt ))) {
809
+ return failure ();
810
+ }
811
+
812
+ if (failed (processInstruction (opcode, instOperands))) {
813
+ return failure ();
814
+ }
815
+ } while (opcode != spirv::Opcode::OpGraphEndARM);
816
+
817
+ return success ();
818
+ }
819
+
820
+ LogicalResult
821
+ spirv::Deserializer::processOpGraphSetOutputARM (ArrayRef<uint32_t > operands) {
822
+ if (operands.size () != 2 ) {
823
+ return emitError (
824
+ unknownLoc,
825
+ " expected value id and output index for OpGraphSetOutputARM" );
826
+ }
827
+
828
+ uint32_t id = operands[0 ];
829
+ Value value = getValue (id);
830
+ if (!value) {
831
+ return emitError (unknownLoc, " could not find result <id> " ) << id;
832
+ }
833
+
834
+ IntegerAttr outputIndexAttr = getConstantInt (operands[1 ]);
835
+ if (!outputIndexAttr) {
836
+ return emitError (unknownLoc,
837
+ " unable to read outputIndex value from constant op " )
838
+ << operands[1 ];
839
+ }
840
+ graphOutputs[outputIndexAttr.getInt ()] = value;
841
+ return success ();
842
+ }
843
+
844
+ LogicalResult
845
+ spirv::Deserializer::processGraphEndARM (ArrayRef<uint32_t > operands) {
846
+ // Create GraphOutputsARM instruction.
847
+ opBuilder.create <spirv::GraphOutputsARMOp>(unknownLoc, graphOutputs);
848
+
849
+ // Process OpGraphEndARM.
850
+ if (!operands.empty ()) {
851
+ return emitError (unknownLoc, " unexpected operands for OpGraphEndARM" );
852
+ }
853
+
854
+ curBlock = nullptr ;
855
+ curGraph = std::nullopt ;
856
+ graphOutputs.clear ();
857
+
858
+ LLVM_DEBUG ({
859
+ logger.unindent ();
860
+ logger.startLine ()
861
+ << " //===-------------------------------------------===//\n " ;
862
+ });
863
+ return success ();
864
+ }
865
+
672
866
std::optional<std::pair<Attribute, Type>>
673
867
spirv::Deserializer::getConstant (uint32_t id) {
674
868
auto constIt = constantMap.find (id);
@@ -701,6 +895,14 @@ std::string spirv::Deserializer::getFunctionSymbol(uint32_t id) {
701
895
return funcName;
702
896
}
703
897
898
+ std::string spirv::Deserializer::getGraphSymbol (uint32_t id) {
899
+ std::string graphName = nameMap.lookup (id).str ();
900
+ if (graphName.empty ()) {
901
+ graphName = " spirv_graph_" + std::to_string (id);
902
+ }
903
+ return graphName;
904
+ }
905
+
704
906
std::string spirv::Deserializer::getSpecConstantSymbol (uint32_t id) {
705
907
auto constName = nameMap.lookup (id).str ();
706
908
if (constName.empty ()) {
@@ -723,6 +925,14 @@ spirv::Deserializer::createSpecConstant(Location loc, uint32_t resultID,
723
925
return op;
724
926
}
725
927
928
+ std::optional<spirv::GraphConstantARMOpMaterializationInfo>
929
+ spirv::Deserializer::getGraphConstantARM (uint32_t id) {
930
+ auto graphConstIt = graphConstantMap.find (id);
931
+ if (graphConstIt == graphConstantMap.end ())
932
+ return std::nullopt ;
933
+ return graphConstIt->getSecond ();
934
+ }
935
+
726
936
LogicalResult
727
937
spirv::Deserializer::processGlobalVariable (ArrayRef<uint32_t > operands) {
728
938
unsigned wordIndex = 0 ;
@@ -944,6 +1154,8 @@ LogicalResult spirv::Deserializer::processType(spirv::Opcode opcode,
944
1154
return processMatrixType (operands);
945
1155
case spirv::Opcode::OpTypeTensorARM:
946
1156
return processTensorARMType (operands);
1157
+ case spirv::Opcode::OpTypeGraphARM:
1158
+ return processGraphTypeARM (operands);
947
1159
default :
948
1160
return emitError (unknownLoc, " unhandled type instruction" );
949
1161
}
@@ -1311,6 +1523,35 @@ spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
1311
1523
return success ();
1312
1524
}
1313
1525
1526
+ LogicalResult
1527
+ spirv::Deserializer::processGraphTypeARM (ArrayRef<uint32_t > operands) {
1528
+ unsigned size = operands.size ();
1529
+ if (size < 2 ) {
1530
+ return emitError (unknownLoc, " OpTypeGraphARM must have at least 2 operands "
1531
+ " (result_id, num_inputs, (inout0_type, "
1532
+ " inout1_type, ...))" )
1533
+ << size;
1534
+ }
1535
+ uint32_t numInputs = operands[1 ];
1536
+ SmallVector<Type, 1 > argTypes;
1537
+ SmallVector<Type, 1 > returnTypes;
1538
+ for (unsigned i = 2 ; i < size; ++i) {
1539
+ Type inOutTy = getType (operands[i]);
1540
+ if (!inOutTy) {
1541
+ return emitError (unknownLoc,
1542
+ " OpTypeGraphARM references undefined element type." )
1543
+ << operands[i];
1544
+ }
1545
+ if (i - 2 >= numInputs) {
1546
+ returnTypes.push_back (inOutTy);
1547
+ } else {
1548
+ argTypes.push_back (inOutTy);
1549
+ }
1550
+ }
1551
+ typeMap[operands[0 ]] = GraphType::get (context, argTypes, returnTypes);
1552
+ return success ();
1553
+ }
1554
+
1314
1555
LogicalResult
1315
1556
spirv::Deserializer::processTypeForwardPointer (ArrayRef<uint32_t > operands) {
1316
1557
if (operands.size () != 2 )
@@ -1823,6 +2064,34 @@ spirv::Deserializer::processConstantNull(ArrayRef<uint32_t> operands) {
1823
2064
<< resultType;
1824
2065
}
1825
2066
2067
+ LogicalResult
2068
+ spirv::Deserializer::processGraphConstantARM (ArrayRef<uint32_t > operands) {
2069
+ if (operands.size () < 3 ) {
2070
+ return emitError (unknownLoc)
2071
+ << " OpGraphConstantARM must have at least 2 operands" ;
2072
+ }
2073
+
2074
+ Type resultType = getType (operands[0 ]);
2075
+ if (!resultType) {
2076
+ return emitError (unknownLoc, " undefined result type from <id> " )
2077
+ << operands[0 ];
2078
+ }
2079
+
2080
+ uint32_t resultID = operands[1 ];
2081
+
2082
+ if (!dyn_cast<spirv::TensorArmType>(resultType)) {
2083
+ return emitError (unknownLoc, " result must be of type OpTypeTensorARM" );
2084
+ }
2085
+
2086
+ APInt graph_constant_id = APInt (32 , operands[2 ], /* isSigned=*/ true );
2087
+ Type i32Ty = opBuilder.getIntegerType (32 );
2088
+ IntegerAttr attr = opBuilder.getIntegerAttr (i32Ty, graph_constant_id);
2089
+ graphConstantMap.try_emplace (
2090
+ resultID, GraphConstantARMOpMaterializationInfo{resultType, attr});
2091
+
2092
+ return success ();
2093
+ }
2094
+
1826
2095
// ===----------------------------------------------------------------------===//
1827
2096
// Control flow
1828
2097
// ===----------------------------------------------------------------------===//
@@ -1920,6 +2189,24 @@ LogicalResult spirv::Deserializer::processLabel(ArrayRef<uint32_t> operands) {
1920
2189
return success ();
1921
2190
}
1922
2191
2192
+ LogicalResult spirv::Deserializer::createGraphBlock (uint32_t graphID) {
2193
+ if (!curGraph) {
2194
+ return emitError (unknownLoc, " a graph block must appear inside a graph" );
2195
+ }
2196
+
2197
+ // We may have forward declared this block.
2198
+ Block *block = getOrCreateBlock (graphID);
2199
+ LLVM_DEBUG (logger.startLine ()
2200
+ << " [block] populating block " << block << " \n " );
2201
+ // If we have seen this block, make sure it was just a forward declaration.
2202
+ assert (block->empty () && " re-deserialize the same block!" );
2203
+
2204
+ opBuilder.setInsertionPointToStart (block);
2205
+ blockMap[graphID] = curBlock = block;
2206
+
2207
+ return success ();
2208
+ }
2209
+
1923
2210
LogicalResult
1924
2211
spirv::Deserializer::processSelectionMerge (ArrayRef<uint32_t > operands) {
1925
2212
if (!curBlock) {
0 commit comments