@@ -22,29 +22,22 @@ namespace paddle {
22
22
namespace framework {
23
23
namespace ir {
24
24
25
+ namespace {
25
26
constexpr int nodes_removed = 3 ;
26
27
constexpr int nodes_added = 1 ;
27
28
28
29
void SetOp (ProgramDesc* prog, const std::string& type,
29
- const std::vector<std::string>& inputs,
30
- const std::vector <std::string>& outputs ) {
30
+ const std::vector<std::pair<std:: string, std::string> >& inputs,
31
+ const std::pair <std::string, std::string >& output ) {
31
32
auto op = prog->MutableBlock (0 )->AppendOp ();
32
33
op->SetType (type);
34
+ op->SetAttr (" use_mkldnn" , true );
33
35
34
- if (type == " conv2d" ) {
35
- op->SetAttr (" use_mkldnn" , true );
36
- op->SetInput (" Input" , {inputs[0 ]});
37
- op->SetInput (" Bias" , {inputs[1 ]});
38
- op->SetInput (" Filter" , {inputs[2 ]});
39
- op->SetOutput (" Output" , outputs);
40
- } else if (type == " elementwise_add" ) {
41
- op->SetInput (" X" , {inputs[0 ]});
42
- op->SetInput (" Y" , {inputs[1 ]});
43
- op->SetOutput (" Out" , outputs);
44
- } else if (type == " relu" || type == " sigmoid" ) {
45
- op->SetInput (" X" , {inputs[0 ]});
46
- op->SetOutput (" Out" , outputs);
36
+ for (const auto & input : inputs) {
37
+ op->SetInput (input.first , {input.second });
47
38
}
39
+
40
+ op->SetOutput (output.first , {output.second });
48
41
}
49
42
50
43
struct IsReachable {
@@ -96,30 +89,59 @@ struct IsReachable {
96
89
}
97
90
};
98
91
99
- TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu ) {
100
- auto build_program_desc = [&]() -> ProgramDesc {
101
- ProgramDesc prog ;
102
- for ( auto & v : std::vector<std::string>(
103
- { " a " , " b " , " bias " , " weights " , " c " , " d " , " e " , " f " } )) {
104
- auto * var = prog. MutableBlock ( 0 )->Var (v);
105
- var-> SetType (proto::VarType::LOD_TENSOR) ;
106
- if (v == " weights " || v == " bias " ) {
107
- var-> SetPersistable ( true );
108
- }
92
+ void AssertOpsCount ( const std::unique_ptr<ir::Graph>& graph ) {
93
+ int conv_count = 0 ;
94
+ int elementwise_add_count = 0 ;
95
+
96
+ for ( auto * node : graph-> Nodes ( )) {
97
+ if (node-> IsOp () && node-> Op ( )->Type () == " conv2d " ) {
98
+ ++conv_count ;
99
+ }
100
+ if (node-> IsOp () && node-> Op ()-> Type () == " elementwise_add " ) {
101
+ ++elementwise_add_count;
109
102
}
103
+ }
104
+ EXPECT_EQ (conv_count, 1 );
105
+ EXPECT_EQ (elementwise_add_count, 0 );
106
+ }
107
+
108
+ ProgramDesc BuildProgramDesc (const std::vector<std::string>& transient_vars,
109
+ const std::vector<std::string>& persistent_vars) {
110
+ ProgramDesc prog;
110
111
111
- SetOp ( &prog, " conv2d " , { " a " , " bias " , " weights " }, { " b " });
112
- SetOp (&prog, " elementwise_add " , { " b " , " c " }, { " d " } );
113
- SetOp (&prog, " relu " , { " d " }, { " e " } );
112
+ auto add_var_to_prog = [ &prog]( const std::string& var_name) -> VarDesc* {
113
+ auto var = prog. MutableBlock ( 0 )-> Var (var_name );
114
+ var-> SetType (proto::VarType::LOD_TENSOR );
114
115
115
- return prog ;
116
+ return var ;
116
117
};
117
118
118
- auto prog = build_program_desc ();
119
+ for (const auto & v : transient_vars) {
120
+ add_var_to_prog (v);
121
+ }
122
+
123
+ for (const auto & v : persistent_vars) {
124
+ auto var = add_var_to_prog (v);
125
+ var->SetPersistable (true );
126
+ }
127
+
128
+ return prog;
129
+ }
130
+ } // namespace
131
+
132
+ TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
133
+ auto prog =
134
+ BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" , " f" }, {" bias" , " weights" });
135
+
136
+ SetOp (&prog, " conv2d" ,
137
+ {{" Input" , " a" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
138
+ {" Output" , " b" });
139
+ SetOp (&prog, " elementwise_add" , {{" X" , " b" }, {" Y" , " c" }}, {" Out" , " d" });
140
+ SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
141
+
119
142
std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
120
143
121
144
IsReachable is_reachable;
122
-
123
145
EXPECT_TRUE (is_reachable (graph)(" a" , " relu" ));
124
146
125
147
auto pass =
@@ -132,40 +154,45 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
132
154
133
155
EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
134
156
current_nodes_num);
135
- // Assert conv_relu op in newly generated graph
136
- int conv_count = 0 ;
137
- int elementwise_add_count = 0 ;
138
157
139
- for (auto * node : graph->Nodes ()) {
140
- if (node->IsOp () && node->Op ()->Type () == " conv2d" ) {
141
- ++conv_count;
142
- }
143
- if (node->IsOp () && node->Op ()->Type () == " elementwise_add" ) {
144
- ++elementwise_add_count;
145
- }
146
- }
147
- EXPECT_EQ (conv_count, 1 );
148
- EXPECT_EQ (elementwise_add_count, 0 );
158
+ AssertOpsCount (graph);
149
159
}
150
160
151
- TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
152
- auto build_program_desc = [&]() -> ProgramDesc {
153
- ProgramDesc prog;
154
- for (auto & v : std::vector<std::string>({" a" , " b" , " bias" , " weights" })) {
155
- auto * var = prog.MutableBlock (0 )->Var (v);
156
- var->SetType (proto::VarType::LOD_TENSOR);
157
- if (v == " weights" || v == " bias" ) {
158
- var->SetPersistable (true );
159
- }
160
- }
161
+ TEST (ConvElementwiseAddMKLDNNFusePass,
162
+ ConvolutionWithElementwiseAddReluNoBias) {
163
+ auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" }, {" weights" });
164
+ SetOp (&prog, " conv2d" , {{" Input" , " a" }, {" Filter" , " weights" }},
165
+ {" Output" , " b" });
166
+ SetOp (&prog, " elementwise_add" , {{" X" , " b" }, {" Y" , " c" }}, {" Out" , " d" });
167
+ SetOp (&prog, " relu" , {{" X" , " d" }}, {" Out" , " e" });
161
168
162
- SetOp (&prog, " conv2d" , {" a" , " bias" , " weights" }, {" b" });
163
- SetOp (&prog, " elementwise_add" , {" b" , " c" }, {" d" });
169
+ std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
164
170
165
- return prog;
166
- };
171
+ IsReachable is_reachable;
172
+
173
+ EXPECT_TRUE (is_reachable (graph)(" a" , " relu" ));
174
+
175
+ auto pass =
176
+ PassRegistry::Instance ().Get (" conv_elementwise_add_mkldnn_fuse_pass" );
177
+ int original_nodes_num = graph->Nodes ().size ();
178
+ graph = pass->Apply (std::move (graph));
179
+ int current_nodes_num = graph->Nodes ().size ();
180
+
181
+ EXPECT_TRUE (is_reachable (graph)(" a" , " relu" ));
182
+
183
+ EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
184
+ current_nodes_num);
185
+
186
+ AssertOpsCount (graph);
187
+ }
188
+
189
+ TEST (ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
190
+ auto prog = BuildProgramDesc ({" a" , " b" , " c" , " d" }, {" bias" , " weights" });
191
+ SetOp (&prog, " conv2d" ,
192
+ {{" Input" , " a" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
193
+ {" Output" , " b" });
194
+ SetOp (&prog, " elementwise_add" , {{" X" , " b" }, {" Y" , " c" }}, {" Out" , " d" });
167
195
168
- auto prog = build_program_desc ();
169
196
std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
170
197
171
198
IsReachable is_reachable;
@@ -181,43 +208,19 @@ TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
181
208
182
209
EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
183
210
current_nodes_num);
184
- // Assert conv_relu op in newly generated graph
185
- int conv_count = 0 ;
186
- int elementwise_add_count = 0 ;
187
-
188
- for (auto * node : graph->Nodes ()) {
189
- if (node->IsOp () && node->Op ()->Type () == " conv2d" ) {
190
- ++conv_count;
191
- }
192
- if (node->IsOp () && node->Op ()->Type () == " elementwise_add" ) {
193
- ++elementwise_add_count;
194
- }
195
- }
196
- EXPECT_EQ (conv_count, 1 );
197
- EXPECT_EQ (elementwise_add_count, 0 );
211
+ AssertOpsCount (graph);
198
212
}
199
213
200
214
TEST (ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
201
- auto build_program_desc = [&]() -> ProgramDesc {
202
- ProgramDesc prog;
203
- for (auto & v : std::vector<std::string>(
204
- {" a" , " b" , " bias" , " weights" , " c" , " d" , " e" , " f" })) {
205
- auto * var = prog.MutableBlock (0 )->Var (v);
206
- var->SetType (proto::VarType::LOD_TENSOR);
207
- if (v.find (" weights" ) || v.find (" bias" )) {
208
- var->SetPersistable (true );
209
- }
210
- }
211
-
212
- SetOp (&prog, " sigmoid" , {" a" }, {" b" });
213
- SetOp (&prog, " conv2d" , {" b" , " bias" , " weights" }, {" c" });
214
- SetOp (&prog, " elementwise_add" , {" c" , " d" }, {" e" });
215
- SetOp (&prog, " relu" , {" e" }, {" f" });
216
-
217
- return prog;
218
- };
215
+ auto prog =
216
+ BuildProgramDesc ({" a" , " b" , " c" , " d" , " e" , " f" }, {" bias" , " weights" });
217
+ SetOp (&prog, " sigmoid" , {{" X" , " a" }}, {" Out" , " b" });
218
+ SetOp (&prog, " conv2d" ,
219
+ {{" Input" , " b" }, {" Bias" , " bias" }, {" Filter" , " weights" }},
220
+ {" Output" , " c" });
221
+ SetOp (&prog, " elementwise_add" , {{" X" , " c" }, {" Y" , " d" }}, {" Out" , " e" });
222
+ SetOp (&prog, " relu" , {{" X" , " e" }}, {" Out" , " f" });
219
223
220
- auto prog = build_program_desc ();
221
224
std::unique_ptr<ir::Graph> graph (new ir::Graph (prog));
222
225
223
226
IsReachable is_reachable;
@@ -234,20 +237,7 @@ TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
234
237
235
238
EXPECT_EQ (original_nodes_num - nodes_removed + nodes_added,
236
239
current_nodes_num);
237
- // Assert conv_relu op in newly generated graph
238
- int conv_count = 0 ;
239
- int elementwise_add_count = 0 ;
240
-
241
- for (auto * node : graph->Nodes ()) {
242
- if (node->IsOp () && node->Op ()->Type () == " conv2d" ) {
243
- ++conv_count;
244
- }
245
- if (node->IsOp () && node->Op ()->Type () == " elementwise_add" ) {
246
- ++elementwise_add_count;
247
- }
248
- }
249
- EXPECT_EQ (conv_count, 1 );
250
- EXPECT_EQ (elementwise_add_count, 0 );
240
+ AssertOpsCount (graph);
251
241
}
252
242
253
243
} // namespace ir
0 commit comments