@@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
40
40
op->SetOutput (output.first , {output.second });
41
41
}
42
42
43
- struct IsReachable {
43
+ struct TestIsReachable {
44
44
using func = std::function<bool (const std::string&, const std::string&)>;
45
45
46
46
auto operator ()(const std::unique_ptr<ir::Graph>& graph) -> func {
@@ -89,7 +89,9 @@ struct IsReachable {
89
89
}
90
90
};
91
91
92
- void AssertOpsCount (const std::unique_ptr<ir::Graph>& graph) {
92
+ void AssertOpsCount (const std::unique_ptr<ir::Graph>& graph,
93
+ int expected_conv_count,
94
+ int expected_elementwise_add_count = 0 ) {
93
95
int conv_count = 0 ;
94
96
int elementwise_add_count = 0 ;
95
97
@@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
101
103
++elementwise_add_count;
102
104
}
103
105
}
104
- EXPECT_EQ (conv_count, 1 );
105
- EXPECT_EQ (elementwise_add_count, 0 );
106
+ EXPECT_EQ (conv_count, expected_conv_count );
107
+ EXPECT_EQ (elementwise_add_count, expected_elementwise_add_count );
106
108
}
107
109
108
110
ProgramDesc BuildProgramDesc (const std::vector<std::string>& transient_vars,
@@ -127,117 +129,112 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
127
129
128
130
return prog;
129
131
}
130
- } // namespace
131
-
132
- TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
133
- auto prog =
134
- BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" , " f" }, {" bias" , " weights" });
135
-
136
- SetOp (&prog, " conv2d" ,
137
- {{" Input" , " a" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
138
- {" Output" , " b" });
139
- SetOp (&prog, " elementwise_add" , {{" X" , " b" }, {" Y" , " c" }}, {" Out" , " d" });
140
- SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
141
132
142
- std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
133
+ void RunPassAndAssert (ProgramDesc* prog, const std::string& from,
134
+ const std::string& to, int expected_conv_num) {
135
+ std::unique_ptr<ir::Graph> graph (new ir::Graph (*prog));
143
136
144
- IsReachable is_reachable;
145
- EXPECT_TRUE (is_reachable (graph)(" a " , " relu " ));
137
+ TestIsReachable is_reachable;
138
+ EXPECT_TRUE (is_reachable (graph)(from, to ));
146
139
147
140
auto pass =
148
141
PassRegistry::Instance ().Get (" conv_elementwise_add_mkldnn_fuse_pass" );
149
142
int original_nodes_num = graph->Nodes ().size ();
150
143
graph = pass->Apply (std::move (graph));
151
144
int current_nodes_num = graph->Nodes ().size ();
152
145
153
- EXPECT_TRUE (is_reachable (graph)(" a " , " relu " ));
146
+ EXPECT_TRUE (is_reachable (graph)(from, to ));
154
147
155
148
EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
156
149
current_nodes_num);
157
150
158
- AssertOpsCount (graph);
151
+ AssertOpsCount (graph, expected_conv_num );
159
152
}
153
+ } // namespace
160
154
161
- TEST (ConvElementwiseAddMKLDNNFusePass,
162
- ConvolutionWithElementwiseAddReluNoBias) {
163
- auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" }, {" weights" });
164
- SetOp (&prog, " conv2d" , {{" Input" , " a" }, {" Filter" , " weights" }},
165
- {" Output" , " b" });
166
- SetOp (&prog, " elementwise_add" , {{" X" , " b" }, {" Y" , " c" }}, {" Out" , " d" });
167
- SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
168
-
169
- std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
155
+ TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
156
+ auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" }, {" bias" , " weights" });
170
157
171
- IsReachable is_reachable;
158
+ SetOp (&prog, " sigmoid" , {{" X" , " a" }}, {" Out" , " b" });
159
+ SetOp (&prog, " conv2d" ,
160
+ {{" Input" , " b" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
161
+ {" Output" , " c" });
172
162
173
- EXPECT_TRUE (is_reachable (graph)(" a" , " relu" ));
163
+ SetOp (&prog, " elementwise_add" , {{" X" , " a" }, {" Y" , " c" }}, {" Out" , " d" });
164
+ SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
174
165
175
- auto pass =
176
- PassRegistry::Instance ().Get (" conv_elementwise_add_mkldnn_fuse_pass" );
177
- int original_nodes_num = graph->Nodes ().size ();
178
- graph = pass->Apply (std::move (graph));
179
- int current_nodes_num = graph->Nodes ().size ();
166
+ RunPassAndAssert (&prog, " a" , " relu" , 1 );
167
+ }
180
168
181
- EXPECT_TRUE (is_reachable (graph)(" a" , " relu" ));
169
+ TEST (ConvElementwiseAddMKLDNNFusePass,
170
+ ConvolutionAsYWithElementwiseAddReluNoBias) {
171
+ auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" }, {" weights" });
182
172
183
- EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
184
- current_nodes_num);
173
+ SetOp (&prog, " sigmoid" , {{" X" , " a" }}, {" Out" , " b" });
174
+ SetOp (&prog, " conv2d" , {{" Input" , " b" }, {" Filter" , " weights" }},
175
+ {" Output" , " c" });
176
+ SetOp (&prog, " elementwise_add" , {{" X" , " a" }, {" Y" , " c" }}, {" Out" , " d" });
177
+ SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
185
178
186
- AssertOpsCount (graph );
179
+ RunPassAndAssert (&prog, " a " , " relu " , 1 );
187
180
}
188
181
189
- TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
190
- auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" }, {" bias" , " weights" });
182
+ TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
183
+ auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" }, {" bias" , " weights" });
184
+
185
+ SetOp (&prog, " sigmoid" , {{" X" , " a" }}, {" Out" , " b" });
191
186
SetOp (&prog, " conv2d" ,
192
- {{" Input" , " a" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
193
- {" Output" , " b" });
194
- SetOp (&prog, " elementwise_add" , {{" X" , " b" }, {" Y" , " c" }}, {" Out" , " d" });
187
+ {{" Input" , " b" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
188
+ {" Output" , " c" });
195
189
196
- std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
190
+ SetOp (&prog, " elementwise_add" , {{" X" , " c" }, {" Y" , " a" }}, {" Out" , " d" });
191
+ SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
197
192
198
- IsReachable is_reachable ;
199
- EXPECT_TRUE ( is_reachable (graph)( " a " , " d " ));
193
+ RunPassAndAssert (&prog, " a " , " relu " , 1 ) ;
194
+ }
200
195
201
- auto pass =
202
- PassRegistry::Instance ().Get (" conv_elementwise_add_mkldnn_fuse_pass" );
203
- int original_nodes_num = graph->Nodes ().size ();
204
- graph = pass->Apply (std::move (graph));
205
- int current_nodes_num = graph->Nodes ().size ();
196
+ TEST (ConvElementwiseAddMKLDNNFusePass,
197
+ ConvolutionAsXWithElementwiseAddReluNoBias) {
198
+ auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" }, {" weights" });
206
199
207
- EXPECT_FALSE (is_reachable (graph)(" a" , " d" ));
200
+ SetOp (&prog, " sigmoid" , {{" X" , " a" }}, {" Out" , " b" });
201
+ SetOp (&prog, " conv2d" , {{" Input" , " b" }, {" Filter" , " weights" }},
202
+ {" Output" , " c" });
203
+ SetOp (&prog, " elementwise_add" , {{" X" , " c" }, {" Y" , " a" }}, {" Out" , " d" });
204
+ SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
208
205
209
- EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
210
- current_nodes_num);
211
- AssertOpsCount (graph);
206
+ RunPassAndAssert (&prog, " a" , " relu" , 1 );
212
207
}
213
208
214
- TEST (ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu ) {
209
+ TEST (ConvElementwiseAddMKLDNNFusePass, NoFusion ) {
215
210
auto prog =
216
- BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" , " f" }, {" bias" , " weights" });
211
+ BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" , " f" , " g" }, {" weights" });
212
+
217
213
SetOp (&prog, " sigmoid" , {{" X" , " a" }}, {" Out" , " b" });
218
- SetOp (&prog, " conv2d" ,
219
- {{" Input" , " b" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
214
+ SetOp (&prog, " conv2d" , {{" Input" , " b" }, {" Filter" , " weights" }},
220
215
{" Output" , " c" });
221
- SetOp (&prog, " elementwise_add" , {{" X" , " c" }, {" Y" , " d" }}, {" Out" , " e" });
222
- SetOp (&prog, " relu" , {{" X" , " e" }}, {" Out" , " f" });
223
216
224
- std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
217
+ SetOp (&prog, " conv2d" , {{" Input" , " d" }, {" Filter" , " weights" }},
218
+ {" Output" , " e" });
225
219
226
- IsReachable is_reachable;
220
+ SetOp (&prog, " elementwise_add" , {{" X" , " c" }, {" Y" , " e" }}, {" Out" , " f" });
221
+ SetOp (&prog, " relu" , {{" X" , " f" }}, {" Out" , " g" });
227
222
228
- EXPECT_TRUE (is_reachable (graph)(" a" , " f" ));
223
+ std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
224
+
225
+ TestIsReachable is_reachable;
226
+ EXPECT_TRUE (is_reachable (graph)(" a" , " g" ));
229
227
230
228
auto pass =
231
229
PassRegistry::Instance ().Get (" conv_elementwise_add_mkldnn_fuse_pass" );
232
230
int original_nodes_num = graph->Nodes ().size ();
233
231
graph = pass->Apply (std::move (graph));
234
232
int current_nodes_num = graph->Nodes ().size ();
235
233
236
- EXPECT_TRUE (is_reachable (graph)(" a" , " f" ));
234
+ EXPECT_TRUE (is_reachable (graph)(" a" , " g" ));
235
+ EXPECT_EQ (original_nodes_num, current_nodes_num);
237
236
238
- EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
239
- current_nodes_num);
240
- AssertOpsCount (graph);
237
+ AssertOpsCount (graph, 2 , 1 );
241
238
}
242
239
243
240
} // namespace ir
0 commit comments