15
15
#include " mkldnn.hpp"
16
16
#include " paddle/fluid/operators/activation_op.h"
17
17
#include " paddle/fluid/operators/mkldnn_activation_op.h"
18
+ #include " paddle/fluid/platform/mkldnn_helper.h"
18
19
19
20
namespace paddle {
20
21
namespace operators {
@@ -23,6 +24,18 @@ using paddle::framework::Tensor;
23
24
using paddle::platform::MKLDNNDeviceContext;
24
25
25
26
namespace {
27
+ std::string gethash (const mkldnn::memory::dims &operand_dims,
28
+ const mkldnn::algorithm algorithm) {
29
+ auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
30
+ std::string dstr = " " ;
31
+ for (size_t i = 0 ; i < operand_dims.size (); ++i) {
32
+ dstr += std::to_string (operand_dims[i]) + " -" ;
33
+ }
34
+ return dstr;
35
+ };
36
+ return dim2str (operand_dims) + std::to_string (algorithm);
37
+ }
38
+
26
39
template <typename T, typename ExecContext>
27
40
void eltwise_forward (const ExecContext &ctx, mkldnn::algorithm algorithm,
28
41
const T alpha = 0 , const T beta = 0 ) {
@@ -37,42 +50,70 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
37
50
const auto *src_data = src->template data <T>();
38
51
39
52
auto *dst = ctx.template Output <Tensor>(" Out" );
40
- const T *dst_data = dst->template mutable_data <T>(ctx.GetPlace ());
53
+ T *dst_data = dst->template mutable_data <T>(ctx.GetPlace ());
41
54
42
55
// get memory dim
43
56
PADDLE_ENFORCE (src->dims ().size () == 2 || src->dims ().size () == 4 ,
44
57
" Input dim must be with 2 or 4" );
45
58
std::vector<int > src_tz = framework::vectorize2int (src->dims ());
46
59
47
- // create memory description
48
- auto data_md = src_tz.size () == 2
49
- ? platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
50
- mkldnn::memory::format::nc)
51
- : platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
52
- mkldnn::memory::format::nchw);
53
-
54
- // create memory primitives
55
- auto src_memory =
56
- mkldnn::memory ({data_md, mkldnn_engine},
57
- static_cast <void *>(const_cast <float *>(src_data)));
58
- auto dst_memory =
59
- mkldnn::memory ({data_md, mkldnn_engine},
60
- static_cast <void *>(const_cast <float *>(dst_data)));
61
-
62
- auto forward_desc = mkldnn::eltwise_forward::desc (
63
- mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
64
-
65
- // save prim desc into global device context to be referred in backward path
66
- const std::string key = ctx.op ().Output (" Out" );
67
- const std::string key_eltwise_pd = key + " @eltwise_pd" ;
68
- auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
69
- forward_desc, mkldnn_engine);
70
- dev_ctx.SetBlob (key_eltwise_pd, forward_pd);
71
-
72
- auto eltwise = mkldnn::eltwise_forward (*forward_pd, src_memory, dst_memory);
60
+ const std::string key = gethash (src_tz, algorithm);
61
+ const std::string key_src_data =
62
+ key + ctx.op ().Output (" Out" ) + " @eltwise_fwd_src_data" ;
63
+ const std::string key_src_mem = key + " @eltwise_fwd_src_mem" ;
64
+ const std::string key_dst_mem = key + " @eltwise_fwd_dst_mem" ;
65
+ const std::string key_fwd = key + " @eltwise_fwd" ;
66
+
67
+ auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
68
+ dev_ctx.GetBlob (key_fwd));
69
+
70
+ // save input data to be referred in backward path
71
+ auto p_src_data = std::make_shared<const T *>(src_data);
72
+ dev_ctx.SetBlob (key_src_data, p_src_data);
73
+
74
+ if (p_fwd == nullptr ) {
75
+ // create memory description
76
+ auto data_md = src_tz.size () == 2
77
+ ? platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
78
+ mkldnn::memory::format::nc)
79
+ : platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
80
+ mkldnn::memory::format::nchw);
81
+
82
+ // create memory primitives
83
+ auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory (
84
+ {data_md, mkldnn_engine}, platform::to_void_cast (src_data)));
85
+ dev_ctx.SetBlob (key_src_mem, p_src_mem);
86
+
87
+ auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory (
88
+ {data_md, mkldnn_engine}, platform::to_void_cast (dst_data)));
89
+ dev_ctx.SetBlob (key_dst_mem, p_dst_mem);
90
+
91
+ auto fwd_desc = mkldnn::eltwise_forward::desc (
92
+ mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
93
+ auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
94
+ fwd_desc, mkldnn_engine);
95
+ const std::string key_fwd_pd = key + " eltwise_fwd_pd" ;
96
+ dev_ctx.SetBlob (key_fwd_pd, p_fwd_pd);
97
+ p_fwd = std::make_shared<mkldnn::eltwise_forward>(
98
+ *p_fwd_pd, *(p_src_mem.get ()), *(p_dst_mem.get ()));
99
+ dev_ctx.SetBlob (key_fwd, p_fwd);
100
+ } else {
101
+ // primitives already exist
102
+ auto p_src_mem =
103
+ std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob (key_src_mem));
104
+ PADDLE_ENFORCE (p_src_mem != nullptr ,
105
+ " Fail to find eltwise p_src_mem in device context." );
106
+ auto p_dst_mem =
107
+ std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob (key_dst_mem));
108
+ PADDLE_ENFORCE (p_dst_mem != nullptr ,
109
+ " Fail to find eltwise p_src_mem in device context." );
110
+
111
+ p_src_mem->set_data_handle (platform::to_void_reinterpret_cast (src_data));
112
+ p_dst_mem->set_data_handle (dst_data);
113
+ }
73
114
74
115
// push primitive to stream and wait until it's executed
75
- std::vector<mkldnn::primitive> pipeline = {eltwise };
116
+ std::vector<mkldnn::primitive> pipeline = {*(p_fwd. get ()) };
76
117
mkldnn::stream (mkldnn::stream::kind::eager).submit (pipeline).wait ();
77
118
}
78
119
@@ -83,8 +124,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
83
124
const auto &mkldnn_engine = dev_ctx.GetEngine ();
84
125
85
126
// get buffers
86
- const auto *x = ctx.template Input <Tensor>(" X" );
87
- const auto *src = x->template data <T>();
127
+ const auto *out = ctx.template Input <Tensor>(" Out" );
88
128
89
129
auto *dout = ctx.template Input <Tensor>(framework::GradVarName (" Out" ));
90
130
const auto *diff_dst = dout->template data <T>();
@@ -94,45 +134,73 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
94
134
const T *diff_src = dx->template mutable_data <T>(ctx.GetPlace ());
95
135
96
136
// get memory dim
97
- std::vector<int > src_tz = framework::vectorize2int (x->dims ());
98
-
99
- // create memory description
100
- auto data_md = src_tz.size () == 2
101
- ? platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
102
- mkldnn::memory::format::nc)
103
- : platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
104
- mkldnn::memory::format::nchw);
105
-
106
- // create memory primitives
107
- auto src_memory = mkldnn::memory (
108
- {data_md, mkldnn_engine}, static_cast <void *>(const_cast <float *>(src)));
109
- auto diff_src_memory =
110
- mkldnn::memory ({data_md, mkldnn_engine},
111
- static_cast <void *>(const_cast <float *>(diff_src)));
112
- auto diff_dst_memory =
113
- mkldnn::memory ({data_md, mkldnn_engine},
114
- static_cast <void *>(const_cast <float *>(diff_dst)));
115
-
116
- auto backward_desc =
117
- mkldnn::eltwise_backward::desc (algorithm, data_md, data_md, alpha, beta);
118
-
119
- // retrieve eltwise primitive desc from device context
120
- const std::string key = ctx.op ().Input (" Out" );
121
- const std::string key_eltwise_pd = key + " @eltwise_pd" ;
122
- const std::shared_ptr<void > forward_pd = dev_ctx.GetBlob (key_eltwise_pd);
123
- PADDLE_ENFORCE (forward_pd != nullptr ,
124
- " Fail to find eltwise_pd in device context" );
125
- auto *p_forward_pd =
126
- static_cast <mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get ());
127
-
128
- auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc (
129
- backward_desc, mkldnn_engine, *p_forward_pd);
130
-
131
- auto eltwise_bwd = mkldnn::eltwise_backward (eltwise_bwd_prim_desc, src_memory,
132
- diff_dst_memory, diff_src_memory);
137
+ std::vector<int > src_tz = framework::vectorize2int (out->dims ());
138
+
139
+ const std::string key = gethash (src_tz, algorithm);
140
+ const std::string key_diff_src_mem = key + " @eltwise_diff_src_mem" ;
141
+ const std::string key_diff_dst_mem = key + " @eltwise_diff_dst_mem" ;
142
+ const std::string key_grad = key + " @eltwise_grad" ;
143
+
144
+ const std::string key_src_data =
145
+ key + ctx.op ().Input (" Out" ) + " @eltwise_fwd_src_data" ;
146
+ const auto p_src_data =
147
+ std::static_pointer_cast<T *>(dev_ctx.GetBlob (key_src_data));
148
+
149
+ const std::string key_src_mem = key + " @eltwise_fwd_src_mem" ;
150
+ auto p_src_mem =
151
+ std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob (key_src_mem));
152
+ p_src_mem->set_data_handle (*p_src_data.get ());
153
+
154
+ auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
155
+ dev_ctx.GetBlob (key_grad));
156
+
157
+ if (p_grad == nullptr ) {
158
+ // create memory description
159
+ auto data_md = src_tz.size () == 2
160
+ ? platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
161
+ mkldnn::memory::format::nc)
162
+ : platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
163
+ mkldnn::memory::format::nchw);
164
+
165
+ // create memory primitives
166
+ std::shared_ptr<void > p_diff_src_mem =
167
+ std::make_shared<mkldnn::memory>(mkldnn::memory (
168
+ {data_md, mkldnn_engine}, platform::to_void_cast (diff_src)));
169
+ dev_ctx.SetBlob (key_diff_src_mem, p_diff_src_mem);
170
+ std::shared_ptr<void > p_diff_dst_mem =
171
+ std::make_shared<mkldnn::memory>(mkldnn::memory (
172
+ {data_md, mkldnn_engine}, platform::to_void_cast (diff_dst)));
173
+ dev_ctx.SetBlob (key_diff_dst_mem, p_diff_dst_mem);
174
+
175
+ auto bwd_desc = mkldnn::eltwise_backward::desc (algorithm, data_md, data_md,
176
+ alpha, beta);
177
+
178
+ const std::string key_fwd_pd = key + " eltwise_fwd_pd" ;
179
+ auto *p_fwd_pd = static_cast <mkldnn::eltwise_forward::primitive_desc *>(
180
+ dev_ctx.GetBlob (key_fwd_pd).get ());
181
+
182
+ auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc (
183
+ bwd_desc, mkldnn_engine, *p_fwd_pd);
184
+
185
+ p_grad = std::make_shared<mkldnn::eltwise_backward>(
186
+ eltwise_bwd_prim_desc, *static_cast <mkldnn::memory *>(p_src_mem.get ()),
187
+ *(static_cast <mkldnn::memory *>(p_diff_dst_mem.get ())),
188
+ *(static_cast <mkldnn::memory *>(p_diff_src_mem.get ())));
189
+ } else {
190
+ // primitives already exist
191
+ auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
192
+ dev_ctx.GetBlob (key_diff_src_mem));
193
+ auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
194
+ dev_ctx.GetBlob (key_diff_dst_mem));
195
+
196
+ p_diff_src_mem->set_data_handle (
197
+ platform::to_void_reinterpret_cast (diff_src));
198
+ p_diff_dst_mem->set_data_handle (
199
+ platform::to_void_reinterpret_cast (diff_dst));
200
+ }
133
201
134
202
// push primitive to stream and wait until it's executed
135
- std::vector<mkldnn::primitive> pipeline = {eltwise_bwd };
203
+ std::vector<mkldnn::primitive> pipeline = {*(p_grad. get ()) };
136
204
mkldnn::stream (mkldnn::stream::kind::eager).submit (pipeline).wait ();
137
205
}
138
206
} // anonymous namespace
0 commit comments