@@ -26,8 +26,6 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
26
26
PADDLE_ENFORCE (graph.get ());
27
27
FusePassBase::Init (" conv_relu_mkldnn_fuse" , graph.get ());
28
28
29
- std::unordered_set<Node*> nodes2delete;
30
-
31
29
GraphPatternDetector gpd;
32
30
auto * conv_input = gpd.mutable_pattern ()
33
31
->NewNode (" conv_relu_mkldnn_fuse/conv_input" )
@@ -42,36 +40,20 @@ std::unique_ptr<ir::Graph> ConvReLUFusePass::ApplyImpl(
42
40
Graph* g) {
43
41
VLOG (4 ) << " handle ConvReLU fuse" ;
44
42
GET_IR_NODE_FROM_SUBGRAPH (conv_weight, conv_weight,
45
- conv_relu_pattern); // Filter
46
- GET_IR_NODE_FROM_SUBGRAPH (conv_bias, conv_bias, conv_relu_pattern); // Bias
47
- GET_IR_NODE_FROM_SUBGRAPH (conv_out, conv_out, conv_relu_pattern); // tmp
43
+ conv_relu_pattern); // Filter
44
+ GET_IR_NODE_FROM_SUBGRAPH (conv_out, conv_out, conv_relu_pattern); // tmp
48
45
GET_IR_NODE_FROM_SUBGRAPH (conv, conv, conv_relu_pattern); // CONV op
49
46
GET_IR_NODE_FROM_SUBGRAPH (relu_out, relu_out, conv_relu_pattern); // Out
50
47
GET_IR_NODE_FROM_SUBGRAPH (relu, relu, conv_relu_pattern); // ReLU op
51
48
52
- // Create an ConvReLU Node.
53
- OpDesc desc;
54
- std::string conv_relu_i_in = subgraph.at (conv_input)->Name ();
55
- std::string conv_relu_w_in = conv_weight->Name ();
56
- std::string conv_relu_b_in = conv_bias->Name ();
57
- std::string conv_relu_out = relu_out->Name ();
58
- desc.SetInput (" Input" , std::vector<std::string>({conv_relu_i_in}));
59
- desc.SetInput (" Filter" , std::vector<std::string>({conv_relu_w_in}));
60
- desc.SetInput (" Bias" , std::vector<std::string>({conv_relu_b_in}));
61
- desc.SetOutput (" Output" , std::vector<std::string>({conv_relu_out}));
62
- desc.SetType (" conv2d" );
63
- for (auto & attr : conv->Op ()->GetAttrMap ()) {
64
- desc.SetAttr (attr.first , attr.second );
65
- }
66
- desc.SetAttr (" fuse_relu" , true );
67
- auto conv_relu_node = g->CreateOpNode (&desc); // OpDesc will be copied.
68
- GraphSafeRemoveNodes (graph.get (), {conv, relu, conv_out});
49
+ // Transform Conv node into ConvReLU node.
50
+ OpDesc* desc = conv->Op ();
51
+ desc->SetOutput (" Output" , std::vector<std::string>({relu_out->Name ()}));
52
+ desc->SetAttr (" fuse_relu" , true );
53
+ GraphSafeRemoveNodes (graph.get (), {relu, conv_out});
69
54
70
55
PADDLE_ENFORCE (subgraph.count (conv_input));
71
- IR_NODE_LINK_TO (subgraph.at (conv_input), conv_relu_node);
72
- IR_NODE_LINK_TO (conv_weight, conv_relu_node);
73
- IR_NODE_LINK_TO (conv_bias, conv_relu_node);
74
- IR_NODE_LINK_TO (conv_relu_node, relu_out);
56
+ IR_NODE_LINK_TO (conv, relu_out);
75
57
76
58
found_conv_relu_count++;
77
59
};
0 commit comments