@@ -18,6 +18,26 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using mkldnn::memory;  // Note: paddle also has a "memory" namespace
+using mkldnn::pooling_forward;
+using mkldnn::pooling_backward;
+
+// Generate keys for storing/retrieving primitives for this operator
+// TODO(jczaja): Make hashing function more optimal
+static std::string gethash(memory::dims& input_dims, std::string& pooling_type,
+                           std::vector<int>& ksize, std::vector<int>& strides,
+                           std::vector<int>& paddings, std::string suffix) {
+  auto dims2str = [](memory::dims& operand_dims) {
+    std::string dstr = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      dstr += std::to_string(operand_dims[i]) + "-";
+    }
+    return dstr;
+  };
+  return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
+         dims2str(paddings) + pooling_type + suffix;
+}
+
 template <typename T>
 class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
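
For readers tracing the caching scheme: `gethash` builds the cache key by flattening every shape and attribute into a dash-separated string and appending the pooling type plus a per-operator suffix. Below is a minimal standalone sketch of the same logic, using plain `std::vector<int>` in place of `memory::dims` and a made-up shape and `pool2d_0.tmp_0` output name:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Mirrors the dims2str/gethash logic added in this patch.
static std::string gethash(std::vector<int>& input_dims,
                           std::string& pooling_type, std::vector<int>& ksize,
                           std::vector<int>& strides,
                           std::vector<int>& paddings, std::string suffix) {
  auto dims2str = [](std::vector<int>& operand_dims) {
    std::string dstr = "";
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
         dims2str(paddings) + pooling_type + suffix;
}

int main() {
  std::vector<int> src_tz = {1, 3, 224, 224};  // example NCHW dims
  std::vector<int> ksize = {2, 2}, strides = {2, 2}, paddings = {0, 0};
  std::string pooling_type = "max";
  // Prints "1-3-224-224-2-2-2-2-0-0-maxpool2d_0.tmp_0": distinct shapes or
  // attributes therefore map to distinct device-context keys.
  std::cout << gethash(src_tz, pooling_type, ksize, strides, paddings,
                       "pool2d_0.tmp_0") << std::endl;
  return 0;
}
```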
@@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     // Get a unique name from "argument" name of "Out" variable
     // This name will be used as key when saving info into device context
-    const std::string key = ctx.op().Output("Out");
-    const std::string key_pool_pd = key + "@pool_pd";
-    const std::string key_pool_workspace_memory =
-        key + "@pool_workspace_memory";
 
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
@@ -63,37 +79,71 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
-    // TODO(pzelazko-intel): support more formats
-    auto src_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                          mkldnn::memory::format::nchw);
-    auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
-                                          mkldnn::memory::format::nchw);
-
-    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
-        CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
-                            pooling_type, mkldnn_engine);
-
-    // save pool_pd into global device context to be referred in backward path
-    dev_ctx.SetBlob(key_pool_pd, pool_pd);
-
-    std::shared_ptr<mkldnn::memory> workspace_memory =
-        CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
-
-    // save pool_workspace_memory to be referred in backward path
-    dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
-
-    auto src_memory =
-        mkldnn::memory({src_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(input_data)));
-    auto dst_memory =
-        mkldnn::memory({dst_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(output_data)));
+    const std::string key = gethash(src_tz, pooling_type, ksize, strides,
+                                    paddings, ctx.op().Output("Out"));
+    const std::string key_pool_p = key + "@pool_p";
+    const std::string key_pool_pd = key + "@pool_pd";
+    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
+    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
+    const std::string key_pool_workspace_memory =
+        key + "@pool_workspace_memory";
 
-    auto pool_prim = mkldnn::pooling_forward(*pool_pd, src_memory, dst_memory,
-                                             *workspace_memory);
+    auto pool_p =
+        std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
+    if (pool_p == nullptr) {
+      // TODO(pzelazko-intel): support more formats
+
+      auto src_md =
+          platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+      auto dst_md =
+          platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+
+      std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
+          CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
+                              pooling_type, mkldnn_engine);
+
+      // save pool_pd into global device context to be referred in backward path
+      dev_ctx.SetBlob(key_pool_pd, pool_pd);
+
+      std::shared_ptr<mkldnn::memory> workspace_memory =
+          CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
+
+      // save pool_workspace_memory to be referred in backward path
+      dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
+
+      auto pool_src_memory_p = std::make_shared<memory>(
+          memory::primitive_desc{src_md, mkldnn_engine},
+          static_cast<void*>(const_cast<T*>(input_data)));
+      dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
+
+      auto pool_dst_memory_p = std::make_shared<memory>(
+          memory::primitive_desc{dst_md, mkldnn_engine},
+          static_cast<void*>(output_data));
+      dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
+
+      pool_p = std::make_shared<pooling_forward>(
+          *pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
+          *workspace_memory);
+      dev_ctx.SetBlob(key_pool_p, pool_p);
+    } else {
+      // Primitives already exist
+      auto pool_src_memory_p =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
+      PADDLE_ENFORCE(pool_src_memory_p != nullptr,
+                     "Fail to find pooling src mem_p in device context");
+      auto pool_dst_memory_p =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
+      PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
+                     "Fail to find pooling dst mem_p in device context");
+      pool_src_memory_p->set_data_handle(
+          reinterpret_cast<void*>(const_cast<T*>(input_data)));
+      pool_dst_memory_p->set_data_handle(output_data);
+    }
 
     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{pool_prim};
+    std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }
 
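
The forward `Compute` now follows a get-or-create pattern: look the primitive up in the device context under its hashed key; on a miss, build the memory descriptors, primitive descriptor, workspace, and memory objects once and store each under its own key; on a hit, reuse them all and only rebind the data pointers. This is also why the key block moved below the shape extraction: `gethash` needs `src_tz`. A minimal sketch of the pattern in isolation (the `BlobMap` variable and `PoolingPrimitive` type are illustrative stand-ins, not Paddle's `MKLDNNDeviceContext` API):

```cpp
#include <map>
#include <memory>
#include <string>

// Illustrative stand-in for the device context's type-erased blob storage.
using BlobMap = std::map<std::string, std::shared_ptr<void>>;
static BlobMap blob_map;

struct PoolingPrimitive { /* expensive-to-build object, e.g. an MKLDNN primitive */ };

std::shared_ptr<PoolingPrimitive> GetOrCreate(const std::string& key) {
  auto it = blob_map.find(key);
  if (it != blob_map.end()) {
    // Cache hit: reuse the primitive built on a previous iteration.
    return std::static_pointer_cast<PoolingPrimitive>(it->second);
  }
  // Cache miss: construct once and remember it under this key.
  auto p = std::make_shared<PoolingPrimitive>();
  blob_map[key] = p;
  return p;
}
```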
@@ -120,9 +170,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mkldnn::memory::primitive_desc workspace_md =
         pooling_type == "max"
             ? pool_pd->workspace_primitive_desc()
-            : mkldnn::memory::primitive_desc(
-                  {{}, mkldnn::memory::f32, mkldnn::memory::format::nchw},
-                  engine);
+            : mkldnn::memory::primitive_desc({{},
+                                              platform::MKLDNNGetDataType<T>(),
+                                              mkldnn::memory::format::nchw},
+                                             engine);
 
     auto p_workspace_memory = new mkldnn::memory(workspace_md);
     return std::unique_ptr<mkldnn::memory>(p_workspace_memory);
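
The hard-coded `mkldnn::memory::f32` is replaced by `platform::MKLDNNGetDataType<T>()`, so the descriptor now follows the kernel's element type `T`. The helper's definition is not part of this diff; a plausible sketch is a function template specialized per supported type (an assumption about its shape, not the verbatim Paddle code):

```cpp
#include <mkldnn.hpp>

// Sketch: map a C++ element type to the matching MKLDNN data type,
// falling back to data_undef for types the backend does not support.
template <typename T>
mkldnn::memory::data_type MKLDNNGetDataType() {
  return mkldnn::memory::data_undef;
}

template <>
mkldnn::memory::data_type MKLDNNGetDataType<float>() {
  return mkldnn::memory::f32;  // float tensors map to MKLDNN's f32
}
```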
@@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
 
-    // Get a unique name from "argument" name of "Out" variable
-    // This name will be used as key when referring info from device context
-    const std::string key = ctx.op().Input("Out");
-    const std::string key_pool_pd = key + "@pool_pd";
-    const std::string key_pool_workspace_memory =
-        key + "@pool_workspace_memory";
-
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
@@ -171,43 +215,76 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> diff_dst_tz =
         paddle::framework::vectorize2int(out_grad->dims());
 
-    auto diff_src_md = platform::MKLDNNMemDesc(diff_src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-    auto diff_dst_md = platform::MKLDNNMemDesc(diff_dst_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-
-    // Retrieve pool_pd/pool_workspace_memory from device context
-    auto pool_pd =
-        std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
-            dev_ctx.GetBlob(key_pool_pd));
-    PADDLE_ENFORCE(pool_pd != nullptr,
-                   "Fail to find pool_pd in device context");
-
-    auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
-        dev_ctx.GetBlob(key_pool_workspace_memory));
-    PADDLE_ENFORCE(workspace_memory != nullptr,
-                   "Fail to find workspace_memory in device context");
-
-    auto pool_bwd_desc = mkldnn::pooling_backward::desc(
-        pooling_type == "max" ? mkldnn::algorithm::pooling_max
-                              : mkldnn::algorithm::pooling_avg,
-        diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
-        mkldnn::padding_kind::zero);
-    auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
-        pool_bwd_desc, mkldnn_engine, *pool_pd);
-
-    auto diff_src_memory =
-        mkldnn::memory({diff_src_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(in_x_grad_data)));
-    auto diff_dst_memory =
-        mkldnn::memory({diff_dst_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(out_grad_data)));
+    // Get a unique name from "argument" name of "Out" variable
+    // This name will be used as key when referring info from device context
+    const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
+                                    paddings, ctx.op().Input("Out"));
+    const std::string key_pool_bwd_p = key + "@pool_bwd_p";
+    const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
+    const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
+    const std::string key_pool_pd = key + "@pool_pd";
+    const std::string key_pool_workspace_memory =
+        key + "@pool_workspace_memory";
 
-    auto bwd_prim = mkldnn::pooling_backward(
-        pool_bwd_pd, diff_dst_memory, *workspace_memory, diff_src_memory);
+    auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
+        dev_ctx.GetBlob(key_pool_bwd_p));
+    if (pool_bwd_p == nullptr) {
+      auto diff_src_md =
+          platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+      auto diff_dst_md =
+          platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+      // Retrieve pool_pd/pool_workspace_memory from device context
+      auto pool_pd =
+          std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
+              dev_ctx.GetBlob(key_pool_pd));
+      PADDLE_ENFORCE(pool_pd != nullptr,
+                     "Fail to find pool_pd in device context");
+
+      auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
+          dev_ctx.GetBlob(key_pool_workspace_memory));
+      PADDLE_ENFORCE(workspace_memory != nullptr,
+                     "Fail to find workspace_memory in device context");
+
+      auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
+          {diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
+      dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
+
+      auto pool_diff_dst_memory_p = std::make_shared<memory>(
+          memory({diff_dst_md, mkldnn_engine},
+                 static_cast<void*>(const_cast<T*>(out_grad_data))));
+      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
+
+      auto pool_bwd_desc = mkldnn::pooling_backward::desc(
+          pooling_type == "max" ? mkldnn::algorithm::pooling_max
+                                : mkldnn::algorithm::pooling_avg,
+          diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
+          mkldnn::padding_kind::zero);
+      auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
+          pool_bwd_desc, mkldnn_engine, *pool_pd);
+
+      pool_bwd_p = std::make_shared<pooling_backward>(
+          pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
+          *(pool_diff_src_memory_p));
+      dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
+    } else {
+      // Primitives already exist
+      auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
+          dev_ctx.GetBlob(key_pool_diff_src_mem_p));
+      PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
+                     "Fail to find pooling src mem_p in device context");
+      auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
+          dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
+      PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
+                     "Fail to find pooling dst mem_p in device context");
+      pool_diff_src_memory_p->set_data_handle(
+          reinterpret_cast<void*>(in_x_grad_data));
+      pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
+    }
 
     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{bwd_prim};
+    std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }  // Compute()
 };
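
A final detail in both cached branches: the stored `mkldnn::memory` objects wrap raw pointers into tensor buffers, and nothing guarantees those buffers sit at the same address on the next call, so every cache hit rebinds them via `set_data_handle` before the primitive is resubmitted. A schematic of that reuse contract (the `CachedMemory` class is an illustrative model, not the MKLDNN API):

```cpp
// Illustrative model of an MKLDNN memory primitive: the descriptor
// (shape, layout) is fixed at construction, but the data pointer can
// be swapped cheaply between executions.
class CachedMemory {
 public:
  explicit CachedMemory(void* handle) : handle_(handle) {}
  void set_data_handle(void* handle) { handle_ = handle; }
  void* get_data_handle() const { return handle_; }

 private:
  void* handle_;
};

void OnCacheHit(CachedMemory& src_mem, float* current_input_buffer) {
  // Descriptor unchanged, so the cached object is reused; only refresh
  // the pointer the primitive will read from on this iteration.
  src_mem.set_data_handle(current_input_buffer);
}
```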