Commit b3fd9da

Merge pull request #11101 from mozga-intel/mozga-intel/Pool_mkldnn_layout
MKLDNN layout: Support for pool operator
2 parents: 627d7a6 + 36031cb

1 file changed: 128 additions, 54 deletions


paddle/fluid/operators/pool_mkldnn_op.cc

Lines changed: 128 additions & 54 deletions
@@ -18,9 +18,14 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-using mkldnn::memory;  // Note: paddle has also "memory" namespace
-using mkldnn::pooling_forward;
+using framework::DataLayout;
+using mkldnn::memory;
 using mkldnn::pooling_backward;
+using mkldnn::pooling_forward;
+using mkldnn::primitive;
+using mkldnn::reorder;
+using mkldnn::stream;
+using platform::to_void_cast;
 
 // Generate keys for storing/retriving primitives for this operator
 // TODO(jczaja): Make hashing function more optimial
@@ -55,8 +60,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* input = ctx.Input<Tensor>("X");
     Tensor* output = ctx.Output<Tensor>("Out");
 
-    // Get an unique name from "argument" name of "Out" variable
-    // This name will be used as key when saving info into device context
+    PADDLE_ENFORCE(input->layout() == DataLayout::kMKLDNN &&
+                       input->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input tensor");
 
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
@@ -82,6 +88,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
+    auto input_format = input->format();
+    memory::format output_format{memory::format::format_undef};
+
     const std::string key = gethash(src_tz, pooling_type, ksize, strides,
                                     paddings, ctx.op().Output("Out"));
     const std::string key_pool_p = key + "@pool_p";
@@ -94,16 +103,17 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto pool_p =
         std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
     if (pool_p == nullptr) {
-      // TODO(pzelazko-intel): support more formats
+      auto src_md = platform::MKLDNNMemDesc(
+          src_tz, platform::MKLDNNGetDataType<T>(), input_format);
 
-      auto src_md =
-          platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-      auto dst_md =
-          platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
+      /* create memory descriptor for pooling without specified format
+       * ('any') which lets a primitive (pooling in this case) choose
+       * the memory format preferred for best performance
+       */
+      auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
+                                            mkldnn::memory::format::any);
 
-      std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
+      std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
           CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
                               pooling_type, mkldnn_engine);
 
@@ -116,20 +126,22 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       // save pool_workspace_memory to be referred in backward path
       dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
 
-      auto pool_src_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{src_md, mkldnn_engine},
-          static_cast<void*>(const_cast<T*>(input_data)));
-      dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
+      auto src_memory = std::make_shared<memory>(pool_pd->src_primitive_desc(),
+                                                 to_void_cast<T>(input_data));
+      auto dst_memory =
+          std::make_shared<memory>(pool_pd->dst_primitive_desc(), output_data);
 
-      auto pool_dst_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{dst_md, mkldnn_engine},
-          static_cast<void*>(output_data));
-      dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
+      dev_ctx.SetBlob(key_pool_src_mem_p, src_memory);
+      dev_ctx.SetBlob(key_pool_dst_mem_p, dst_memory);
+
+      pool_p = std::make_shared<pooling_forward>(*pool_pd, *(src_memory.get()),
+                                                 *(dst_memory.get()),
+                                                 *workspace_memory);
 
-      pool_p = std::make_shared<pooling_forward>(
-          *pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
-          *workspace_memory);
       dev_ctx.SetBlob(key_pool_p, pool_p);
+
+      output_format =
+          (memory::format)dst_memory->get_primitive_desc().desc().data.format;
     } else {
       // Primitives already exist
       auto pool_src_memory_p =
@@ -140,14 +152,20 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
       PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
                      "Fail to find pooling dst mem_p in device context");
-      pool_src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(const_cast<T*>(input_data)));
+      pool_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
       pool_dst_memory_p->set_data_handle(output_data);
+
+      output_format = (memory::format)pool_dst_memory_p->get_primitive_desc()
+                          .desc()
+                          .data.format;
     }
 
     // push primitive to stream and wait until it's executed
     std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
-    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+    stream(stream::kind::eager).submit(pipeline).wait();
+
+    output->set_layout(DataLayout::kMKLDNN);
+    output->set_format(output_format);
   }
 
  private:
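
The forward kernel above stops hard-coding nchw: the destination descriptor is created with memory::format::any, the pooling primitive picks the layout it prefers, and that chosen format is queried back and recorded on the output tensor via set_format(). Below is a minimal, standalone sketch of that idiom against the MKLDNN 0.x C++ API; it is not part of this commit, and the shapes, pooling parameters, and variable names are illustrative only.

// Sketch (assumptions: MKLDNN 0.x C++ API, made-up 2x8x32x32 input, 2x2 max pooling).
#include <iostream>
#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);

  memory::dims src_tz = {2, 8, 32, 32};
  memory::dims dst_tz = {2, 8, 16, 16};
  memory::dims kernel = {2, 2}, strides = {2, 2}, padding = {0, 0};

  auto src_md = memory::desc(src_tz, memory::data_type::f32, memory::format::nchw);
  // 'any' defers the dst layout decision to the primitive, as the new kernel does.
  auto dst_md = memory::desc(dst_tz, memory::data_type::f32, memory::format::any);

  auto pool_desc = pooling_forward::desc(
      prop_kind::forward_inference, algorithm::pooling_max, src_md, dst_md,
      strides, kernel, padding, padding, padding_kind::zero);
  auto pool_pd = pooling_forward::primitive_desc(pool_desc, eng);

  std::vector<float> src_data(2 * 8 * 32 * 32, 1.f);
  auto src_mem = memory({{src_md}, eng}, src_data.data());
  // Let MKLDNN allocate dst in whatever layout the primitive preferred.
  auto dst_mem = memory(pool_pd.dst_primitive_desc());

  // The same query the kernel performs before output->set_format(...).
  auto chosen_fmt = static_cast<memory::format>(
      pool_pd.dst_primitive_desc().desc().data.format);
  std::cout << "dst format chosen by MKLDNN: " << static_cast<int>(chosen_fmt) << "\n";

  std::vector<primitive> pipeline{pooling_forward(pool_pd, src_mem, dst_mem)};
  stream(stream::kind::eager).submit(pipeline).wait();
  return 0;
}

The query on dst_primitive_desc().desc().data.format is exactly what the kernel stores into output_format before calling output->set_format(output_format).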
@@ -194,6 +212,13 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
 
+    PADDLE_ENFORCE(in_x->layout() == DataLayout::kMKLDNN &&
+                       in_x->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input X tensor");
+    PADDLE_ENFORCE(out_grad->layout() == DataLayout::kMKLDNN &&
+                       out_grad->format() != memory::format::format_undef,
+                   "Wrong layout/format set for Input output_grad tensor");
+
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
@@ -212,6 +237,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 
     const T* out_grad_data = out_grad->data<T>();
     T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
+    memory::format in_x_grad_format{memory::format::format_undef};
 
     std::vector<int> diff_src_tz =
         paddle::framework::vectorize2int(in_x_grad->dims());
@@ -225,39 +251,48 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_pool_bwd_p = key + "@pool_bwd_p";
     const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
     const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
+    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
+    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
     const std::string key_pool_pd = key + "@pool_pd";
     const std::string key_pool_workspace_memory =
         key + "@pool_workspace_memory";
 
+    auto user_diff_dst_memory =
+        memory({{{diff_dst_tz}, memory::data_type::f32, out_grad->format()},
+                mkldnn_engine},
+               to_void_cast<T>(out_grad_data));
+
+    std::shared_ptr<memory> diff_src_memory;
+    std::shared_ptr<memory> diff_dst_memory;
+    auto dst_memory =
+        std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
+    PADDLE_ENFORCE(dst_memory != nullptr,
+                   "Fail to find dst_memory in device context");
+
+    primitive reorder_diff_dst;
+    bool is_diff_dst_reordered = false;
     auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
         dev_ctx.GetBlob(key_pool_bwd_p));
     if (pool_bwd_p == nullptr) {
-      auto diff_src_md =
-          platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
-      auto diff_dst_md =
-          platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
-                                  mkldnn::memory::format::nchw);
+      // Retrieve src_memory/dst_memory saved in forward pass
+      auto src_memory =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
+      PADDLE_ENFORCE(src_memory != nullptr,
+                     "Fail to find src_memory in device context");
       // Retrieve pool_pd/pool_workspace_memory from device context
       auto pool_pd =
          std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
              dev_ctx.GetBlob(key_pool_pd));
       PADDLE_ENFORCE(pool_pd != nullptr,
                      "Fail to find pool_pd in device context");
-
-      auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
+      auto workspace_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_workspace_memory));
       PADDLE_ENFORCE(workspace_memory != nullptr,
                      "Fail to find workspace_memory in device context");
 
-      auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
-          {diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
-      dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
-
-      auto pool_diff_dst_memory_p = std::make_shared<memory>(
-          memory({diff_dst_md, mkldnn_engine},
-                 static_cast<void*>(const_cast<T*>(out_grad_data))));
-      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
+      // create memory descriptors for pooling
+      auto diff_src_md = src_memory.get()->get_primitive_desc().desc();
+      auto diff_dst_md = dst_memory.get()->get_primitive_desc().desc();
 
       auto pool_bwd_desc = mkldnn::pooling_backward::desc(
          pooling_type == "max" ? mkldnn::algorithm::pooling_max
@@ -267,35 +302,74 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
       auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
          pool_bwd_desc, mkldnn_engine, *pool_pd);
 
+      // reorder between user_diff_dst and pool diff_dst if needed
+      diff_dst_memory = std::make_shared<memory>(user_diff_dst_memory);
+      if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory =
+            std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
+
+      diff_src_memory = std::make_shared<memory>(
+          pool_bwd_pd.diff_src_primitive_desc(), in_x_grad_data);
+
+      dev_ctx.SetBlob(key_pool_diff_src_mem_p, diff_src_memory);
+      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, diff_dst_memory);
+
       pool_bwd_p = std::make_shared<pooling_backward>(
-          pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
-          *(pool_diff_src_memory_p));
+          pool_bwd_pd, *(diff_dst_memory.get()), *workspace_memory,
+          *(diff_src_memory));
       dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
+
     } else {
       // Primitives already exist
-      auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
+      diff_src_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_diff_src_mem_p));
-      PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
+      PADDLE_ENFORCE(diff_src_memory != nullptr,
                      "Fail to find pooling src mem_p in device context");
-      auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
+      diff_dst_memory = std::static_pointer_cast<memory>(
          dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
-      PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
+      PADDLE_ENFORCE(diff_dst_memory != nullptr,
                      "Fail to find pooling dst mem_p in device context");
-      pool_diff_src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(in_x_grad_data));
-      pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
+
+      diff_src_memory->set_data_handle(reinterpret_cast<void*>(in_x_grad_data));
+      diff_dst_memory->set_data_handle(const_cast<T*>(out_grad_data));
+
+      // reorder between user_diff_dst and pool diff_dst if needed
+      if (memory::primitive_desc(dst_memory->get_primitive_desc()) !=
+          user_diff_dst_memory.get_primitive_desc()) {
+        diff_dst_memory =
+            std::make_shared<memory>(dst_memory.get()->get_primitive_desc());
+        reorder_diff_dst = reorder(user_diff_dst_memory, *diff_dst_memory);
+        is_diff_dst_reordered = true;
+      }
     }
 
+    in_x_grad_format = (memory::format)diff_src_memory->get_primitive_desc()
+                           .desc()
+                           .data.format;
+
     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
+    std::vector<mkldnn::primitive> pipeline;
+    if (is_diff_dst_reordered) {
+      pipeline.push_back(reorder_diff_dst);
+    }
+    pipeline.push_back(*(pool_bwd_p.get()));
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
+
+    in_x_grad->set_layout(DataLayout::kMKLDNN);
+    in_x_grad->set_format(in_x_grad_format);
   }  // Compute()
 };
 
 }  // namespace operators
 }  // namespace paddle
 
+namespace ops = paddle::operators;
+
 REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::PoolMKLDNNOpKernel<float>);
+                   ops::PoolMKLDNNOpKernel<float>);
 REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
-                   paddle::operators::PoolMKLDNNGradOpKernel<float>);
+                   ops::PoolMKLDNNGradOpKernel<float>);
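
The backward kernel now accepts the incoming gradient in whatever MKLDNN format the upstream op produced; when that format differs from the layout the forward pooling primitive chose for its dst, a reorder primitive is pushed into the pipeline ahead of pooling_backward. Below is a minimal, standalone sketch of that reorder check against the MKLDNN 0.x C++ API; it is not from this commit, the shapes are made up, and nChw8c merely stands in for whatever blocked layout a primitive might prefer.

// Sketch (assumptions: MKLDNN 0.x C++ API, hypothetical 2x8x16x16 diff_dst in NCHW).
#include <vector>
#include "mkldnn.hpp"

int main() {
  using namespace mkldnn;
  engine eng(engine::cpu, 0);

  memory::dims tz = {2, 8, 16, 16};
  std::vector<float> user_buf(2 * 8 * 16 * 16, 1.f);

  // User data arrives in plain NCHW; pretend the primitive expects nChw8c.
  auto user_md = memory::desc(tz, memory::data_type::f32, memory::format::nchw);
  auto prim_md = memory::desc(tz, memory::data_type::f32, memory::format::nChw8c);
  auto user_mem = memory({{user_md}, eng}, user_buf.data());
  auto prim_mem = memory({{prim_md}, eng});  // MKLDNN allocates the blocked buffer

  std::vector<primitive> pipeline;
  bool is_reordered = false;
  // Same primitive_desc comparison the kernel performs before reordering diff_dst.
  if (user_mem.get_primitive_desc() != prim_mem.get_primitive_desc()) {
    pipeline.push_back(reorder(user_mem, prim_mem));
    is_reordered = true;
  }
  // ...pooling_backward, consuming prim_mem as its diff_dst, would be pushed here...
  if (!pipeline.empty()) {
    stream(stream::kind::eager).submit(pipeline).wait();
  }
  return is_reordered ? 0 : 1;
}

In the kernel itself the reordered (or pass-through) diff_dst memory is what gets cached in the device context and fed to pooling_backward, and the diff_src format is queried afterwards to set in_x_grad's layout and format.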
