@@ -18,9 +18,14 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using mkldnn::memory;  // Note: paddle has also "memory" namespace
-using mkldnn::pooling_forward;
+using framework::DataLayout;
+using mkldnn::memory;
 using mkldnn::pooling_backward;
+using mkldnn::pooling_forward;
+using mkldnn::primitive;
+using mkldnn::reorder;
+using mkldnn::stream;
+using platform::to_void_cast;
 
 // Generate keys for storing/retriving primitives for this operator
 // TODO(jczaja): Make hashing function more optimial
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
 
-    // Get an unique name from "argument" name of "Out" variable
-    // This name will be used as key when saving info into device context
+    PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
+                       input->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input tensor");
 
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
+    auto input_format = input->format();
+    memory::format output_format{memory::format::format_undef};
+
     const std::string key = gethash(src_tz, pooling_type, ksize, strides,
                                     paddings, ctx.op().Output("Out"));
     const std::string key_pool_p = key + "@pool_p";
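
The kernel caches its MKL-DNN primitives and memories in the device context under these string keys and recovers them with static_pointer_cast on later calls. A minimal standalone sketch of that caching pattern, not part of the patch; the FakeDeviceContext and Pooling types below are illustrative stand-ins for paddle's MKLDNNDeviceContext and mkldnn::pooling_forward.

#include <memory>
#include <string>
#include <unordered_map>

struct FakeDeviceContext {
  // type-erased blob store, keyed by string, like MKLDNNDeviceContext
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

struct Pooling { /* stands in for mkldnn::pooling_forward */ };

int main() {
  FakeDeviceContext dev_ctx;
  const std::string key_pool_p = "hash@pool_p";

  auto pool_p = std::static_pointer_cast<Pooling>(dev_ctx.GetBlob(key_pool_p));
  if (pool_p == nullptr) {
    pool_p = std::make_shared<Pooling>();  // first call: build and cache
    dev_ctx.SetBlob(key_pool_p, pool_p);
  }
  // subsequent calls with the same key reuse the cached primitive
  return 0;
}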
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto pool_p =
         std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
     if (pool_p == nullptr) {
-      // TODO(pzelazko-intel): support more formats
+      auto src_md = platform::MKLDNNMemDesc(
+          src_tz, platform::MKLDNNGetDataType<T>(), input_format);
 
-      auto src_md =
-          platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-      auto dst_md =
-          platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
+      /* create memory descriptor for pooling without specified format
+       * ('any') which lets a primitive (pooling in this case) choose
+       * the memory format preferred for best performance
+       */
+      auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
+                                            mkldnn::memory::format::any);
 
-      std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
+      std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
           CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
                               pooling_type, mkldnn_engine);
 
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       // save pool_workspace_memory to be referred in backward path
       dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
 
-      auto pool_src_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{src_md, mkldnn_engine},
-          static_cast<void*>(const_cast<T*>(input_data)));
-      dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
+      auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
+                                                 to_void_cast<T>(input_data));
+      auto dst_memory =
+          std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
 
-      auto pool_dst_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{dst_md, mkldnn_engine},
-          static_cast<void*>(output_data));
-      dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
+      dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
+      dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
+
+      pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
+                                                 *(dst_memory.get()),
+                                                 *workspace_memory);
 
-      pool_p = std::make_shared<pooling_forward>(
-          *pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
-          *workspace_memory);
       dev_ctx.SetBlob(key_pool_p, pool_p);
+
+      output_format =
+          (memory::format)dst_memory->get_primitive_desc().desc().data.format;
     } else {
       // Primitives already exist
       auto pool_src_memory_p =
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
       PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
                      "Fail to find pooling dst mem_p in device context");
-      pool_src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(const_cast<T*>(input_data)));
+      pool_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
       pool_dst_memory_p->set_data_handle(output_data);
+
+      output_format = (memory::format)pool_dst_memory_p->get_primitive_desc()
+                          .desc()
+                          .data.format;
     }
 
     // push primitive to stream and wait until it's executed
     std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
-    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+    stream(stream::kind::eager).submit(pipeline).wait();
+
+    output->set_layout(DataLayout::kMKLDNN);
+    output->set_format(output_format);
   }
 
  private:
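
The hunks above switch the forward kernel from hard-coded nchw memory descriptors to format::any, letting the pooling primitive choose its preferred layout and recording that layout on the output tensor. A standalone sketch of the same pattern against the MKL-DNN 0.x C++ API that this file targets; shapes, the engine index, and variable names are illustrative assumptions, not taken from the patch.

#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);

  memory::dims src_tz = {1, 16, 8, 8};  // assumed NCHW sizes
  memory::dims dst_tz = {1, 16, 4, 4};
  memory::dims kernel = {2, 2}, strides = {2, 2}, padding = {0, 0};

  auto src_md =
      memory::desc(src_tz, memory::data_type::f32, memory::format::nchw);
  // 'any' lets the primitive pick the dst layout it prefers
  auto dst_md =
      memory::desc(dst_tz, memory::data_type::f32, memory::format::any);

  auto pool_desc = pooling_forward::desc(
      prop_kind::forward_training, algorithm::pooling_max, src_md, dst_md,
      strides, kernel, padding, padding, padding_kind::zero);
  auto pool_pd = pooling_forward::primitive_desc(pool_desc, eng);

  std::vector<float> src_data(1 * 16 * 8 * 8, 1.f);
  auto src_mem = memory({src_md, eng}, src_data.data());
  auto dst_mem = memory(pool_pd.dst_primitive_desc());       // layout chosen by MKL-DNN
  auto ws_mem = memory(pool_pd.workspace_primitive_desc());  // needed by max-pooling bwd

  // the kernel stores this chosen format on the output tensor
  auto chosen_format = static_cast<memory::format>(
      dst_mem.get_primitive_desc().desc().data.format);
  (void)chosen_format;

  std::vector<primitive> pipeline{
      pooling_forward(pool_pd, src_mem, dst_mem, ws_mem)};
  stream(stream::kind::eager).submit(pipeline).wait();
  return 0;
}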
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
 
+    PADDLE_ENFORCE(in_x->layout() == DataLayout::kMKLDNN &&
+                       in_x->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input X tensor");
+    PADDLE_ENFORCE(out_grad->layout() == DataLayout::kMKLDNN &&
+                       out_grad->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input output_grad tensor");
+
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     const T* out_grad_data = out_grad->data<T>();
     T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
+    memory::format in_x_grad_format{memory::format::format_undef};
 
     std::vector<int> diff_src_tz =
         paddle::framework::vectorize2int(in_x_grad->dims());
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_pool_bwd_p = key + "@pool_bwd_p";
     const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
     const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
+    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
+    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
     const std::string key_pool_pd = key + "@pool_pd";
     const std::string key_pool_workspace_memory =
         key + "@pool_workspace_memory";
 
+    auto user_diff_dst_memory =
+        memory({{{diff_dst_tz}, memory::data_type::f32, out_grad->format()},
+                mkldnn_engine},
+               to_void_cast<T>(out_grad_data));
+
+    std::shared_ptr<memory> diff_src_memory;
+    std::shared_ptr<memory> diff_dst_memory;
+    auto dst_memory =
+        std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
+    PADDLE_ENFORCE(dst_memory != nullptr,
+                   "Fail to find dst_memory in device context");
+
+    primitive reorder_diff_dst;
+    bool is_diff_dst_reordered = false;
     auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
         dev_ctx.GetBlob(key_pool_bwd_p));
     if (pool_bwd_p == nullptr) {
-      auto diff_src_md =
-          platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-      auto diff_dst_md =
-          platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
+      // Retrieve src_memory/dst_memory saved in forward pass
+      auto src_memory =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
+      PADDLE_ENFORCE(src_memory != nullptr,
+                     "Fail to find src_memory in device context");
       // Retrieve pool_pd/pool_workspace_memory from device context
       auto pool_pd =
           std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
               dev_ctx.GetBlob(key_pool_pd));
       PADDLE_ENFORCE(pool_pd != nullptr,
                      "Fail to find pool_pd in device context");
-
-      auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
+      auto workspace_memory = std::static_pointer_cast<memory>(
           dev_ctx.GetBlob(key_pool_workspace_memory));
       PADDLE_ENFORCE(workspace_memory != nullptr,
                      "Fail to find workspace_memory in device context");
 
-      auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
-          {diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
-      dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
-
-      auto pool_diff_dst_memory_p = std::make_shared<memory>(
-          memory({diff_dst_md, mkldnn_engine},
-                 static_cast<void*>(const_cast<T*>(out_grad_data))));
-      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
+      // create memory descriptors for pooling
+      auto diff_src_md = src_memory.get()->get_primitive_desc().desc();
+      auto diff_dst_md = dst_memory.get()->get_primitive_desc().desc();
 
       auto pool_bwd_desc = mkldnn::pooling_backward::desc(
           pooling_type == "max" ? mkldnn::algorithm::pooling_max
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
       auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
           pool_bwd_desc, mkldnn_engine, *pool_pd);
 
+      // reorder between user_diff_dst and pool diff_dst if needed
+      diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
+      if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory =
+            std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
+
+      diff_src_memory = std::make_shared<memory>(
+          pool_bwd_pd.diff_src_primitive_desc(), in_x_grad_data);
+
+      dev_ctx.SetBlob(key_pool_diff_src_mem_p, diff_src_memory);
+      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory);
+
       pool_bwd_p = std::make_shared<pooling_backward>(
-          pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
-          *(pool_diff_src_memory_p));
+          pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory,
+          *(diff_src_memory));
       dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
+
     } else {
       // Primitives already exist
-      auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
+      diff_src_memory = std::static_pointer_cast<memory>(
           dev_ctx.GetBlob(key_pool_diff_src_mem_p));
-      PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
+      PADDLE_ENFORCE(diff_src_memory != nullptr,
                      "Fail to find pooling src mem_p in device context");
-      auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
+      diff_dst_memory = std::static_pointer_cast<memory>(
           dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
-      PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
+      PADDLE_ENFORCE(diff_dst_memory != nullptr,
                      "Fail to find pooling dst mem_p in device context");
-      pool_diff_src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(in_x_grad_data));
-      pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
+
+      diff_src_memory->set_data_handle(reinterpret_cast<void*>(in_x_grad_data));
+      diff_dst_memory->set_data_handle(const_cast<T*>(out_grad_data));
+
+      // reorder between user_diff_dst and pool diff_dst if needed
+      if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory =
+            std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
     }
 
+    in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc()
+                           .desc()
+                           .data.format;
+
     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
+    std::vector<mkldnn::primitive> pipeline;
+    if (is_diff_dst_reordered) {
+      pipeline.push_back(reorder_diff_dst);
+    }
+    pipeline.push_back(*(pool_bwd_p.get()));
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+
+    in_x_grad->set_layout(DataLayout::kMKLDNN);
+    in_x_grad->set_format(in_x_grad_format);
   }  // Compute()
 };
 
 }  // namespace operators
 }  // namespace paddle
 
+namespace ops = paddle::operators;
+
 REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::PoolMKLDNNOpKernel<float>);
+                   ops::PoolMKLDNNOpKernel<float>);
 REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::PoolMKLDNNGradOpKernel<float>);
+                   ops::PoolMKLDNNGradOpKernel<float>);
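
The backward hunks above guard against the incoming gradient arriving in a different layout than the one the forward primitive chose: when the two primitive descriptors differ, a reorder is prepended to the execution pipeline. A standalone sketch of that guard against the MKL-DNN 0.x C++ API; the nchw/nChw8c formats, sizes, and names here are illustrative assumptions rather than values from the patch.

#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);

  memory::dims diff_dst_tz = {1, 16, 4, 4};
  std::vector<float> user_grad(1 * 16 * 4 * 4, 1.f);

  // gradient as handed over by the framework, in plain nchw
  auto user_diff_dst = memory(
      {{{diff_dst_tz}, memory::data_type::f32, memory::format::nchw}, eng},
      user_grad.data());

  // layout the backward primitive expects; nChw8c stands in for whatever
  // the forward primitive_desc actually chose
  auto expected_pd = memory::primitive_desc(
      {{diff_dst_tz}, memory::data_type::f32, memory::format::nChw8c}, eng);

  std::vector<primitive> pipeline;
  memory diff_dst = user_diff_dst;
  if (expected_pd != user_diff_dst.get_primitive_desc()) {
    diff_dst = memory(expected_pd);  // buffer in the primitive-preferred layout
    pipeline.push_back(reorder(user_diff_dst, diff_dst));
  }

  // pooling_backward would be built on diff_dst and appended here
  if (!pipeline.empty()) {
    stream(stream::kind::eager).submit(pipeline).wait();
  }
  return 0;
}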