Skip to content

Commit 35ff9af

Browse files
committed
Merge pull request opencv#21162 from rogday:softmax_simplification
2 parents dad2b9a + 8294107 commit 35ff9af

File tree

1 file changed

+59
-11
lines changed

1 file changed

+59
-11
lines changed

modules/dnn/src/onnx/onnx_graph_simplifier.cpp

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -107,25 +107,19 @@ class ONNXGraphWrapper : public ImportGraphWrapper
107107
opencv_onnx::GraphProto& net;
108108
};
109109

110-
class SoftMaxSubgraph : public Subgraph
110+
class SoftMaxSubgraphBase : public Subgraph
111111
{
112112
public:
113-
SoftMaxSubgraph() : axis(1)
114-
{
115-
int input = addNodeToMatch("");
116-
int inpExp = addNodeToMatch("Exp", input);
117-
int sum = addNodeToMatch("ReduceSum", inpExp);
118-
addNodeToMatch("Div", inpExp, sum);
119-
setFusedNode("Softmax", input);
120-
}
113+
SoftMaxSubgraphBase() : axis(1), id(-1) {}
121114

122115
virtual bool match(const Ptr<ImportGraphWrapper>& net, int nodeId,
123116
std::vector<int>& matchedNodesIds,
124117
std::vector<int>& targetNodesIds) CV_OVERRIDE
125118
{
126119
if (Subgraph::match(net, nodeId, matchedNodesIds, targetNodesIds))
127120
{
128-
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[1]);
121+
CV_Assert(id >= 0 && id < matchedNodesIds.size());
122+
Ptr<ImportNodeWrapper> sum = net->getNode(matchedNodesIds[id]);
129123
opencv_onnx::NodeProto* node = sum.dynamicCast<ONNXNodeWrapper>()->node;
130124

131125
for (int i = 0; i < node->attribute_size(); i++)
@@ -153,8 +147,60 @@ class SoftMaxSubgraph : public Subgraph
153147
attr->set_i(axis);
154148
}
155149

156-
private:
150+
protected:
157151
int axis;
152+
int id;
153+
};
154+
155+
class SoftMaxSubgraph : public SoftMaxSubgraphBase
156+
{
157+
public:
158+
SoftMaxSubgraph()
159+
{
160+
int input = addNodeToMatch("");
161+
int inpExp = addNodeToMatch("Exp", input);
162+
163+
int sum = addNodeToMatch("ReduceSum", inpExp);
164+
id = 1;
165+
166+
addNodeToMatch("Div", inpExp, sum);
167+
setFusedNode("Softmax", input);
168+
}
169+
};
170+
171+
class SoftMaxSubgraph2 : public SoftMaxSubgraphBase {
172+
public:
173+
SoftMaxSubgraph2() {
174+
int input = addNodeToMatch("");
175+
176+
int reducemax = addNodeToMatch("ReduceMax", input);
177+
id = 0;
178+
179+
int sub = addNodeToMatch("Sub", input, reducemax);
180+
int exp = addNodeToMatch("Exp", sub);
181+
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
182+
addNodeToMatch("Div", exp, reducesum);
183+
setFusedNode("Softmax", input);
184+
}
185+
};
186+
187+
class LogSoftMaxSubgraph : public SoftMaxSubgraphBase
188+
{
189+
public:
190+
LogSoftMaxSubgraph()
191+
{
192+
int input = addNodeToMatch("");
193+
194+
int reducemax = addNodeToMatch("ReduceMax", input);
195+
id = 0;
196+
197+
int sub_1 = addNodeToMatch("Sub", input, reducemax);
198+
int exp = addNodeToMatch("Exp", sub_1);
199+
int reducesum = addNodeToMatch("ReduceSum", exp, addNodeToMatch(""));
200+
int log = addNodeToMatch("Log", reducesum);
201+
addNodeToMatch("Sub", sub_1, log);
202+
setFusedNode("LogSoftmax", input);
203+
}
158204
};
159205

160206
class NormalizeSubgraphBase : public Subgraph
@@ -574,6 +620,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
574620
subgraphs.push_back(makePtr<ResizeSubgraph1>());
575621
subgraphs.push_back(makePtr<ResizeSubgraph2>());
576622
subgraphs.push_back(makePtr<SoftMaxSubgraph>());
623+
subgraphs.push_back(makePtr<SoftMaxSubgraph2>());
624+
subgraphs.push_back(makePtr<LogSoftMaxSubgraph>());
577625
subgraphs.push_back(makePtr<NormalizeSubgraph1>());
578626
subgraphs.push_back(makePtr<NormalizeSubgraph2>());
579627
subgraphs.push_back(makePtr<NormalizeSubgraph2_2>());

0 commit comments

Comments
 (0)