Skip to content

Commit d000008

Browse files
authored
Merge pull request #13552 from sfraczek/sfraczek/conv-relu-update
little update to conv relu fuse pass (fix)
2 parents cc20867 + e5d1bd1 commit d000008

File tree

4 files changed

+17
-45
lines changed

4 files changed

+17
-45
lines changed

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass.cc

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,6 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
2626
PADDLE_ENFORCE(graph.get());
2727
FusePassBase::Init("conv_relu_mkldnn_fuse", graph.get());
2828

29-
std::unordered_set<Node*> nodes2delete;
30-
3129
GraphPatternDetector gpd;
3230
auto* conv_input = gpd.mutable_pattern()
3331
->NewNode("conv_relu_mkldnn_fuse/conv_input")
@@ -42,36 +40,20 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
4240
Graph* g) {
4341
VLOG(4) << "handle ConvReLU fuse";
4442
GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight,
45-
conv_relu_pattern); // Filter
46-
GET_IR_NODE_FROM_SUBGRAPH(conv_bias, conv_bias, conv_relu_pattern); // Bias
47-
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
43+
conv_relu_pattern); // Filter
44+
GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, conv_relu_pattern); // tmp
4845
GET_IR_NODE_FROM_SUBGRAPH(conv, conv, conv_relu_pattern); // CONV op
4946
GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, conv_relu_pattern); // Out
5047
GET_IR_NODE_FROM_SUBGRAPH(relu, relu, conv_relu_pattern); // ReLU op
5148

52-
// Create an ConvReLU Node.
53-
OpDesc desc;
54-
std::string conv_relu_i_in = subgraph.at(conv_input)->Name();
55-
std::string conv_relu_w_in = conv_weight->Name();
56-
std::string conv_relu_b_in = conv_bias->Name();
57-
std::string conv_relu_out = relu_out->Name();
58-
desc.SetInput("Input", std::vector<std::string>({conv_relu_i_in}));
59-
desc.SetInput("Filter", std::vector<std::string>({conv_relu_w_in}));
60-
desc.SetInput("Bias", std::vector<std::string>({conv_relu_b_in}));
61-
desc.SetOutput("Output", std::vector<std::string>({conv_relu_out}));
62-
desc.SetType("conv2d");
63-
for (auto& attr : conv->Op()->GetAttrMap()) {
64-
desc.SetAttr(attr.first, attr.second);
65-
}
66-
desc.SetAttr("fuse_relu", true);
67-
auto conv_relu_node = g->CreateOpNode(&desc); // OpDesc will be copied.
68-
GraphSafeRemoveNodes(graph.get(), {conv, relu, conv_out});
49+
// Transform Conv node into ConvReLU node.
50+
OpDesc* desc = conv->Op();
51+
desc->SetOutput("Output", std::vector<std::string>({relu_out->Name()}));
52+
desc->SetAttr("fuse_relu", true);
53+
GraphSafeRemoveNodes(graph.get(), {relu, conv_out});
6954

7055
PADDLE_ENFORCE(subgraph.count(conv_input));
71-
IR_NODE_LINK_TO(subgraph.at(conv_input), conv_relu_node);
72-
IR_NODE_LINK_TO(conv_weight, conv_relu_node);
73-
IR_NODE_LINK_TO(conv_bias, conv_relu_node);
74-
IR_NODE_LINK_TO(conv_relu_node, relu_out);
56+
IR_NODE_LINK_TO(conv, relu_out);
7557

7658
found_conv_relu_count++;
7759
};

paddle/fluid/framework/ir/conv_relu_mkldnn_fuse_pass_tester.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,13 @@ TEST(ConvReLUFusePass, basic) {
8585

8686
for (auto* node : graph->Nodes()) {
8787
if (node->IsOp() && node->Op()->Type() == "conv2d") {
88-
if (node->Op()->HasAttr("use_mkldnn")) {
89-
bool use_mkldnn = boost::get<bool>(node->Op()->GetAttr("use_mkldnn"));
90-
if (use_mkldnn) {
91-
if (node->Op()->HasAttr("fuse_relu")) {
92-
bool fuse_relu = boost::get<bool>(node->Op()->GetAttr("fuse_relu"));
93-
if (fuse_relu) {
94-
++conv_relu_count;
95-
}
96-
}
97-
}
88+
auto* op = node->Op();
89+
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
90+
EXPECT_TRUE(boost::get<bool>(op->GetAttr("use_mkldnn")));
91+
ASSERT_TRUE(op->HasAttr("fuse_relu"));
92+
bool fuse_relu = boost::get<bool>(op->GetAttr("fuse_relu"));
93+
if (fuse_relu) {
94+
++conv_relu_count;
9895
}
9996
}
10097
}

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -638,11 +638,6 @@ PDNode *patterns::ConvReLU::operator()(
638638
->AsInput()
639639
->assert_is_persistable_var()
640640
->assert_is_op_input("conv2d", "Filter");
641-
// Bias
642-
auto *conv_bias_var = pattern->NewNode(conv_bias_repr())
643-
->AsInput()
644-
->assert_is_persistable_var()
645-
->assert_is_op_input("conv2d", "Bias");
646641
// intermediate variable, will be removed in the IR after fuse.
647642
auto *conv_out_var = pattern->NewNode(conv_out_repr())
648643
->AsIntermediate()
@@ -653,8 +648,7 @@ PDNode *patterns::ConvReLU::operator()(
653648
->AsOutput()
654649
->assert_is_op_output("relu");
655650

656-
conv_op->LinksFrom({conv_input, conv_weight_var, conv_bias_var})
657-
.LinksTo({conv_out_var});
651+
conv_op->LinksFrom({conv_input, conv_weight_var}).LinksTo({conv_out_var});
658652
relu_op->LinksFrom({conv_out_var}).LinksTo({relu_out_var});
659653
return relu_out_var;
660654
}

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ struct PatternBase {
379379
// op: conv + relu
380380
// named nodes:
381381
// conv_input, conv_weight,
382-
// conv_bias, conv_out, conv,
382+
// conv_out, conv,
383383
// relu_out, relu
384384
struct ConvReLU : public PatternBase {
385385
ConvReLU(PDPattern* pattern, const std::string& name_scope)
@@ -392,7 +392,6 @@ struct ConvReLU : public PatternBase {
392392
PATTERN_DECL_NODE(relu);
393393
// declare variable node's name
394394
PATTERN_DECL_NODE(conv_weight);
395-
PATTERN_DECL_NODE(conv_bias);
396395
PATTERN_DECL_NODE(conv_out);
397396
PATTERN_DECL_NODE(relu_out);
398397
};

0 commit comments

Comments (0)