@@ -107,25 +107,19 @@ class ONNXGraphWrapper : public ImportGraphWrapper
107107 opencv_onnx::GraphProto& net;
108108};
109109
110- class SoftMaxSubgraph : public Subgraph
110+ class SoftMaxSubgraphBase : public Subgraph
111111{
112112public:
113- SoftMaxSubgraph () : axis(1 )
114- {
115- int input = addNodeToMatch (" " );
116- int inpExp = addNodeToMatch (" Exp" , input);
117- int sum = addNodeToMatch (" ReduceSum" , inpExp);
118- addNodeToMatch (" Div" , inpExp, sum);
119- setFusedNode (" Softmax" , input);
120- }
113+ SoftMaxSubgraphBase () : axis(1 ), id(-1 ) {}
121114
122115 virtual bool match (const Ptr<ImportGraphWrapper>& net, int nodeId,
123116 std::vector<int >& matchedNodesIds,
124117 std::vector<int >& targetNodesIds) CV_OVERRIDE
125118 {
126119 if (Subgraph::match (net, nodeId, matchedNodesIds, targetNodesIds))
127120 {
128- Ptr<ImportNodeWrapper> sum = net->getNode (matchedNodesIds[1 ]);
121+ CV_Assert (id >= 0 && id < matchedNodesIds.size ());
122+ Ptr<ImportNodeWrapper> sum = net->getNode (matchedNodesIds[id]);
129123 opencv_onnx::NodeProto* node = sum.dynamicCast <ONNXNodeWrapper>()->node ;
130124
131125 for (int i = 0 ; i < node->attribute_size (); i++)
@@ -153,8 +147,60 @@ class SoftMaxSubgraph : public Subgraph
153147 attr->set_i (axis);
154148 }
155149
156- private :
150+ protected :
157151 int axis;
152+ int id;
153+ };
154+
155+ class SoftMaxSubgraph : public SoftMaxSubgraphBase
156+ {
157+ public:
158+ SoftMaxSubgraph ()
159+ {
160+ int input = addNodeToMatch (" " );
161+ int inpExp = addNodeToMatch (" Exp" , input);
162+
163+ int sum = addNodeToMatch (" ReduceSum" , inpExp);
164+ id = 1 ;
165+
166+ addNodeToMatch (" Div" , inpExp, sum);
167+ setFusedNode (" Softmax" , input);
168+ }
169+ };
170+
171+ class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
172+ public:
173+ SoftMaxSubgraph2 () {
174+ int input = addNodeToMatch (" " );
175+
176+ int reducemax = addNodeToMatch (" ReduceMax" , input);
177+ id = 0 ;
178+
179+ int sub = addNodeToMatch (" Sub" , input, reducemax);
180+ int exp = addNodeToMatch (" Exp" , sub);
181+ int reducesum = addNodeToMatch (" ReduceSum" , exp, addNodeToMatch (" " ));
182+ addNodeToMatch (" Div" , exp, reducesum);
183+ setFusedNode (" Softmax" , input);
184+ }
185+ };
186+
187+ class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
188+ {
189+ public:
190+ LogSoftMaxSubgraph ()
191+ {
192+ int input = addNodeToMatch (" " );
193+
194+ int reducemax = addNodeToMatch (" ReduceMax" , input);
195+ id = 0 ;
196+
197+ int sub_1 = addNodeToMatch (" Sub" , input, reducemax);
198+ int exp = addNodeToMatch (" Exp" , sub_1);
199+ int reducesum = addNodeToMatch (" ReduceSum" , exp, addNodeToMatch (" " ));
200+ int log = addNodeToMatch (" Log" , reducesum);
201+ addNodeToMatch (" Sub" , sub_1, log);
202+ setFusedNode (" LogSoftmax" , input);
203+ }
158204};
159205
160206class NormalizeSubgraphBase : public Subgraph
@@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
574620 subgraphs.push_back (makePtr<ResizeSubgraph1>());
575621 subgraphs.push_back (makePtr<ResizeSubgraph2>());
576622 subgraphs.push_back (makePtr<SoftMaxSubgraph>());
623+ subgraphs.push_back (makePtr<SoftMaxSubgraph2>());
624+ subgraphs.push_back (makePtr<LogSoftMaxSubgraph>());
577625 subgraphs.push_back (makePtr<NormalizeSubgraph1>());
578626 subgraphs.push_back (makePtr<NormalizeSubgraph2>());
579627 subgraphs.push_back (makePtr<NormalizeSubgraph2_2>());
0 commit comments