11
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
// See the License for the specific language governing permissions and
13
13
// limitations under the License.
14
+
14
15
#include " paddle/fluid/framework/ir/conv_bias_mkldnn_fuse_pass.h"
16
+ #include < functional>
15
17
#include < string>
16
18
#include < vector>
19
+ #include " paddle/fluid/framework/lod_tensor.h"
17
20
#include " paddle/fluid/platform/enforce.h"
21
+
18
22
namespace paddle {
19
23
namespace framework {
20
24
namespace ir {
25
+
26
+ template <typename BinaryOperation>
27
+ LoDTensor tensor_apply_eltwise (const LoDTensor& vec_a, const LoDTensor& vec_b,
28
+ BinaryOperation f) {
29
+ PADDLE_ENFORCE_EQ (vec_a.dims (), vec_b.dims ());
30
+ LoDTensor vec_y;
31
+ vec_y.Resize (vec_a.dims ());
32
+ const float * a = vec_a.data <float >();
33
+ const float * b = vec_b.data <float >();
34
+ float * y = vec_y.mutable_data <float >(platform::CPUPlace ());
35
+ for (int i = 0 ; i < vec_a.numel (); i++) {
36
+ y[i] = f (a[i], b[i]);
37
+ }
38
+ return vec_y;
39
+ }
40
+
21
41
std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl (
22
42
std::unique_ptr<ir::Graph> graph) const {
23
43
PADDLE_ENFORCE (graph.get ());
24
- FusePassBase::Init (" conv_bias_mkldnn_fuse" , graph.get ());
44
+ FusePassBase::Init (name_scope_, graph.get ());
45
+
46
+ auto * scope = param_scope ();
47
+ PADDLE_ENFORCE (scope);
48
+
25
49
GraphPatternDetector gpd;
26
- auto * conv_input = gpd. mutable_pattern ()
27
- -> NewNode ( " conv_bias_mkldnn_fuse/conv_input " )
28
- -> AsInput ( )
29
- -> assert_is_op_input ( " conv2d " , " Input " );
30
- patterns::ConvBias conv_bias_pattern (gpd. mutable_pattern (),
31
- " conv_bias_mkldnn_fuse " );
50
+ auto * conv_input =
51
+ gpd. mutable_pattern ( )
52
+ -> NewNode ( patterns::PDNodeName (name_scope_, " conv_input " ) )
53
+ -> AsInput ()
54
+ -> assert_is_op_input ( " conv2d " , " Input " );
55
+ patterns::ConvBias conv_bias_pattern (gpd. mutable_pattern (), name_scope_ );
32
56
conv_bias_pattern (conv_input);
33
57
int found_conv_bias_count = 0 ;
34
58
auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
@@ -44,27 +68,55 @@ std::unique_ptr<ir::Graph> ConvBiasFusePass::ApplyImpl(
44
68
GET_IR_NODE_FROM_SUBGRAPH (eltwise_out, eltwise_out, conv_bias_pattern);
45
69
// elementwise_add op
46
70
GET_IR_NODE_FROM_SUBGRAPH (eltwise, eltwise, conv_bias_pattern);
47
- // Create an ConvBias Node.
48
- OpDesc desc;
49
- std::string conv_bias_i_in = subgraph.at (conv_input)->Name ();
50
- std::string conv_bias_w_in = conv_weight->Name ();
51
- std::string conv_bias_b_in = eltwise_bias->Name ();
52
- std::string conv_bias_out = eltwise_out->Name ();
53
- desc.SetInput (" Input" , std::vector<std::string>({conv_bias_i_in}));
54
- desc.SetInput (" Filter" , std::vector<std::string>({conv_bias_w_in}));
55
- desc.SetInput (" Bias" , std::vector<std::string>({conv_bias_b_in}));
56
- desc.SetOutput (" Output" , std::vector<std::string>({conv_bias_out}));
57
- desc.SetType (" conv2d" );
58
- for (auto & attr : conv->Op ()->GetAttrMap ()) {
59
- desc.SetAttr (attr.first , attr.second );
60
- }
61
- auto conv_bias_node = g->CreateOpNode (&desc); // OpDesc will be copied.
62
- GraphSafeRemoveNodes (graph.get (), {conv, eltwise, conv_out});
71
+
63
72
PADDLE_ENFORCE (subgraph.count (conv_input));
64
- IR_NODE_LINK_TO (subgraph.at (conv_input), conv_bias_node);
65
- IR_NODE_LINK_TO (conv_weight, conv_bias_node);
66
- IR_NODE_LINK_TO (eltwise_bias, conv_bias_node);
67
- IR_NODE_LINK_TO (conv_bias_node, eltwise_out);
73
+
74
+ auto * eltwise_bias_tensor =
75
+ scope->FindVar (eltwise_bias->Name ())->GetMutable <LoDTensor>();
76
+
77
+ auto input_names = conv->Op ()->InputNames ();
78
+ bool has_bias = std::find (input_names.begin (), input_names.end (), " Bias" ) !=
79
+ input_names.end ();
80
+ if (has_bias && conv->Op ()->Input (" Bias" ).size () > 0 ) {
81
+ auto conv_bias_names = conv->Op ()->Input (" Bias" );
82
+ // add eltwise bias to existing conv bias
83
+ PADDLE_ENFORCE_EQ (conv_bias_names.size (), 1 );
84
+ auto * conv_bias_var = scope->FindVar (conv_bias_names[0 ]);
85
+ auto * conv_bias_tensor = conv_bias_var->GetMutable <LoDTensor>();
86
+ PADDLE_ENFORCE_EQ (conv_bias_tensor->dims (), eltwise_bias_tensor->dims ());
87
+ *conv_bias_tensor = tensor_apply_eltwise (
88
+ *conv_bias_tensor, *eltwise_bias_tensor, std::plus<float >());
89
+
90
+ conv->Op ()->SetOutput (" Output" ,
91
+ std::vector<std::string>({eltwise_out->Name ()}));
92
+
93
+ GraphSafeRemoveNodes (graph.get (), {eltwise, conv_out});
94
+
95
+ IR_NODE_LINK_TO (conv, eltwise_out);
96
+ } else {
97
+ // take eltwise bias as conv bias
98
+ OpDesc desc;
99
+
100
+ desc.SetInput (
101
+ " Input" , std::vector<std::string>({subgraph.at (conv_input)->Name ()}));
102
+ desc.SetInput (" Filter" , std::vector<std::string>({conv_weight->Name ()}));
103
+ desc.SetInput (" Bias" , std::vector<std::string>({eltwise_bias->Name ()}));
104
+ desc.SetOutput (" Output" , std::vector<std::string>({eltwise_out->Name ()}));
105
+ desc.SetType (" conv2d" );
106
+
107
+ for (auto & attr : conv->Op ()->GetAttrMap ()) {
108
+ desc.SetAttr (attr.first , attr.second );
109
+ }
110
+ auto conv_bias_node = g->CreateOpNode (&desc);
111
+
112
+ IR_NODE_LINK_TO (subgraph.at (conv_input), conv_bias_node);
113
+ IR_NODE_LINK_TO (conv_weight, conv_bias_node);
114
+ IR_NODE_LINK_TO (eltwise_bias, conv_bias_node);
115
+ IR_NODE_LINK_TO (conv_bias_node, eltwise_out);
116
+
117
+ GraphSafeRemoveNodes (graph.get (), {conv, eltwise, conv_out});
118
+ }
119
+
68
120
found_conv_bias_count++;
69
121
};
70
122
gpd (graph.get (), handler);
0 commit comments