Skip to content

Commit dbc4fcd

Browse files
author
Tomasz Patejko
committed
MKLDNN residual connections fuse pass: unit tests enabled and added
1 parent 4224089 commit dbc4fcd

File tree

1 file changed

+67
-70
lines changed

1 file changed

+67
-70
lines changed

paddle/fluid/framework/ir/conv_elementwise_add_mkldnn_fuse_pass_tester.cc

Lines changed: 67 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void SetOp(ProgramDesc* prog, const std::string& type,
4040
op->SetOutput(output.first, {output.second});
4141
}
4242

43-
struct IsReachable {
43+
struct TestIsReachable {
4444
using func = std::function<bool(const std::string&, const std::string&)>;
4545

4646
auto operator()(const std::unique_ptr<ir::Graph>& graph) -> func {
@@ -89,7 +89,9 @@ struct IsReachable {
8989
}
9090
};
9191

92-
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
92+
void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph,
93+
int expected_conv_count,
94+
int expected_elementwise_add_count = 0) {
9395
int conv_count = 0;
9496
int elementwise_add_count = 0;
9597

@@ -101,8 +103,8 @@ void AssertOpsCount(const std::unique_ptr<ir::Graph>& graph) {
101103
++elementwise_add_count;
102104
}
103105
}
104-
EXPECT_EQ(conv_count, 1);
105-
EXPECT_EQ(elementwise_add_count, 0);
106+
EXPECT_EQ(conv_count, expected_conv_count);
107+
EXPECT_EQ(elementwise_add_count, expected_elementwise_add_count);
106108
}
107109

108110
ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
@@ -127,117 +129,112 @@ ProgramDesc BuildProgramDesc(const std::vector<std::string>& transient_vars,
127129

128130
return prog;
129131
}
130-
} // namespace
131-
132-
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionWithElementwiseAddRelu) {
133-
auto prog =
134-
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
135-
136-
SetOp(&prog, "conv2d",
137-
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
138-
{"Output", "b"});
139-
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
140-
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
141132

142-
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
133+
void RunPassAndAssert(ProgramDesc* prog, const std::string& from,
134+
const std::string& to, int expected_conv_num) {
135+
std::unique_ptr<ir::Graph> graph(new ir::Graph(*prog));
143136

144-
IsReachable is_reachable;
145-
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
137+
TestIsReachable is_reachable;
138+
EXPECT_TRUE(is_reachable(graph)(from, to));
146139

147140
auto pass =
148141
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
149142
int original_nodes_num = graph->Nodes().size();
150143
graph = pass->Apply(std::move(graph));
151144
int current_nodes_num = graph->Nodes().size();
152145

153-
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
146+
EXPECT_TRUE(is_reachable(graph)(from, to));
154147

155148
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
156149
current_nodes_num);
157150

158-
AssertOpsCount(graph);
151+
AssertOpsCount(graph, expected_conv_num);
159152
}
153+
} // namespace
160154

161-
TEST(ConvElementwiseAddMKLDNNFusePass,
162-
ConvolutionWithElementwiseAddReluNoBias) {
163-
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
164-
SetOp(&prog, "conv2d", {{"Input", "a"}, {"Filter", "weights"}},
165-
{"Output", "b"});
166-
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
167-
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
168-
169-
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
155+
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsYWithElementwiseAddRelu) {
156+
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
170157

171-
IsReachable is_reachable;
158+
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
159+
SetOp(&prog, "conv2d",
160+
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
161+
{"Output", "c"});
172162

173-
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
163+
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
164+
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
174165

175-
auto pass =
176-
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
177-
int original_nodes_num = graph->Nodes().size();
178-
graph = pass->Apply(std::move(graph));
179-
int current_nodes_num = graph->Nodes().size();
166+
RunPassAndAssert(&prog, "a", "relu", 1);
167+
}
180168

181-
EXPECT_TRUE(is_reachable(graph)("a", "relu"));
169+
TEST(ConvElementwiseAddMKLDNNFusePass,
170+
ConvolutionAsYWithElementwiseAddReluNoBias) {
171+
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
182172

183-
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
184-
current_nodes_num);
173+
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
174+
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
175+
{"Output", "c"});
176+
SetOp(&prog, "elementwise_add", {{"X", "a"}, {"Y", "c"}}, {"Out", "d"});
177+
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
185178

186-
AssertOpsCount(graph);
179+
RunPassAndAssert(&prog, "a", "relu", 1);
187180
}
188181

189-
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionElementwiseAdd) {
190-
auto prog = BuildProgramDesc({"a", "b", "c", "d"}, {"bias", "weights"});
182+
TEST(ConvElementwiseAddMKLDNNFusePass, ConvolutionAsXWithElementwiseAddRelu) {
183+
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"bias", "weights"});
184+
185+
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
191186
SetOp(&prog, "conv2d",
192-
{{"Input", "a"}, {"Bias", "bias"}, {"Filter", "weights"}},
193-
{"Output", "b"});
194-
SetOp(&prog, "elementwise_add", {{"X", "b"}, {"Y", "c"}}, {"Out", "d"});
187+
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
188+
{"Output", "c"});
195189

196-
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
190+
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
191+
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
197192

198-
IsReachable is_reachable;
199-
EXPECT_TRUE(is_reachable(graph)("a", "d"));
193+
RunPassAndAssert(&prog, "a", "relu", 1);
194+
}
200195

201-
auto pass =
202-
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
203-
int original_nodes_num = graph->Nodes().size();
204-
graph = pass->Apply(std::move(graph));
205-
int current_nodes_num = graph->Nodes().size();
196+
TEST(ConvElementwiseAddMKLDNNFusePass,
197+
ConvolutionAsXWithElementwiseAddReluNoBias) {
198+
auto prog = BuildProgramDesc({"a", "b", "c", "d", "e"}, {"weights"});
206199

207-
EXPECT_FALSE(is_reachable(graph)("a", "d"));
200+
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
201+
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
202+
{"Output", "c"});
203+
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "a"}}, {"Out", "d"});
204+
SetOp(&prog, "relu", {{"X", "d"}}, {"Out", "e"});
208205

209-
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
210-
current_nodes_num);
211-
AssertOpsCount(graph);
206+
RunPassAndAssert(&prog, "a", "relu", 1);
212207
}
213208

214-
TEST(ConvElementwiseAddMKLDNNFusePass, SigmoidConvolutionAddElementwiseRelu) {
209+
TEST(ConvElementwiseAddMKLDNNFusePass, NoFusion) {
215210
auto prog =
216-
BuildProgramDesc({"a", "b", "c", "d", "e", "f"}, {"bias", "weights"});
211+
BuildProgramDesc({"a", "b", "c", "d", "e", "f", "g"}, {"weights"});
212+
217213
SetOp(&prog, "sigmoid", {{"X", "a"}}, {"Out", "b"});
218-
SetOp(&prog, "conv2d",
219-
{{"Input", "b"}, {"Bias", "bias"}, {"Filter", "weights"}},
214+
SetOp(&prog, "conv2d", {{"Input", "b"}, {"Filter", "weights"}},
220215
{"Output", "c"});
221-
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "d"}}, {"Out", "e"});
222-
SetOp(&prog, "relu", {{"X", "e"}}, {"Out", "f"});
223216

224-
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
217+
SetOp(&prog, "conv2d", {{"Input", "d"}, {"Filter", "weights"}},
218+
{"Output", "e"});
225219

226-
IsReachable is_reachable;
220+
SetOp(&prog, "elementwise_add", {{"X", "c"}, {"Y", "e"}}, {"Out", "f"});
221+
SetOp(&prog, "relu", {{"X", "f"}}, {"Out", "g"});
227222

228-
EXPECT_TRUE(is_reachable(graph)("a", "f"));
223+
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
224+
225+
TestIsReachable is_reachable;
226+
EXPECT_TRUE(is_reachable(graph)("a", "g"));
229227

230228
auto pass =
231229
PassRegistry::Instance().Get("conv_elementwise_add_mkldnn_fuse_pass");
232230
int original_nodes_num = graph->Nodes().size();
233231
graph = pass->Apply(std::move(graph));
234232
int current_nodes_num = graph->Nodes().size();
235233

236-
EXPECT_TRUE(is_reachable(graph)("a", "f"));
234+
EXPECT_TRUE(is_reachable(graph)("a", "g"));
235+
EXPECT_EQ(original_nodes_num, current_nodes_num);
237236

238-
EXPECT_EQ(original_nodes_num - nodes_removed + nodes_added,
239-
current_nodes_num);
240-
AssertOpsCount(graph);
237+
AssertOpsCount(graph, 2, 1);
241238
}
242239

243240
} // namespace ir

0 commit comments

Comments
 (0)