
Commit 85b6bb5

Merge pull request #10747 from jczaja/prv-mkldnn-pooling-reuse
Reuse of pooling mkldnn primitives
2 parents e526dd5 + 5f13330 commit 85b6bb5

File tree: 2 files changed (+163 −76 lines)

paddle/fluid/operators/pool_mkldnn_op.cc

Lines changed: 153 additions & 76 deletions
@@ -18,6 +18,26 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

+using mkldnn::memory;  // Note: paddle has also "memory" namespace
+using mkldnn::pooling_forward;
+using mkldnn::pooling_backward;
+
+// Generate keys for storing/retrieving primitives for this operator
+// TODO(jczaja): Make hashing function more optimal
+static std::string gethash(memory::dims& input_dims, std::string& pooling_type,
+                           std::vector<int>& ksize, std::vector<int>& strides,
+                           std::vector<int>& paddings, std::string suffix) {
+  auto dims2str = [](memory::dims& operand_dims) {
+    std::string dstr = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      dstr += std::to_string(operand_dims[i]) + "-";
+    }
+    return dstr;
+  };
+  return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
+         dims2str(paddings) + pooling_type + suffix;
+}
+
 template <typename T>
 class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
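
The gethash helper above builds the cache key by concatenating the input dimensions, kernel size, strides, paddings, pooling type, and a per-variable suffix. Below is a minimal, self-contained sketch of the same scheme, assuming mkldnn::memory::dims is a std::vector<int> (it is in MKL-DNN, so a plain alias stands in for it); the main() driver and the sample shapes are hypothetical:

#include <iostream>
#include <string>
#include <vector>

using dims = std::vector<int>;  // stand-in for mkldnn::memory::dims

static std::string gethash(const dims& input_dims, const std::string& pooling_type,
                           const dims& ksize, const dims& strides,
                           const dims& paddings, const std::string& suffix) {
  // Join each dimension list as "d0-d1-...-", then append type and suffix.
  auto dims2str = [](const dims& operand_dims) {
    std::string dstr;
    for (size_t i = 0; i < operand_dims.size(); ++i) {
      dstr += std::to_string(operand_dims[i]) + "-";
    }
    return dstr;
  };
  return dims2str(input_dims) + dims2str(ksize) + dims2str(strides) +
         dims2str(paddings) + pooling_type + suffix;
}

int main() {
  // NCHW input 2x3x224x224, 3x3 max pooling, stride 2, padding 1:
  std::cout << gethash({2, 3, 224, 224}, "max", {3, 3}, {2, 2}, {1, 1}, "pool_out")
            << "\n";  // prints: 2-3-224-224-3-3-2-2-1-1-maxpool_out
}

Two poolings that differ in any dimension or attribute hash to different keys, while repeated iterations over the same shapes map to the same cache entry.
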
@@ -34,10 +54,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

     // Get a unique name from "argument" name of "Out" variable
     // This name will be used as key when saving info into device context
-    const std::string key = ctx.op().Output("Out");
-    const std::string key_pool_pd = key + "@pool_pd";
-    const std::string key_pool_workspace_memory =
-        key + "@pool_workspace_memory";

     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
@@ -63,37 +79,71 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());

-    // TODO(pzelazko-intel): support more formats
-    auto src_md = platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                          mkldnn::memory::format::nchw);
-    auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
-                                          mkldnn::memory::format::nchw);
-
-    std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
-        CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
-                            pooling_type, mkldnn_engine);
-
-    // save pool_pd into global device context to be referred in backward path
-    dev_ctx.SetBlob(key_pool_pd, pool_pd);
-
-    std::shared_ptr<mkldnn::memory> workspace_memory =
-        CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
-
-    // save pool_workspace_memory to be referred in backward path
-    dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
-
-    auto src_memory =
-        mkldnn::memory({src_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(input_data)));
-    auto dst_memory =
-        mkldnn::memory({dst_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(output_data)));
+    const std::string key = gethash(src_tz, pooling_type, ksize, strides,
+                                    paddings, ctx.op().Output("Out"));
+    const std::string key_pool_p = key + "@pool_p";
+    const std::string key_pool_pd = key + "@pool_pd";
+    const std::string key_pool_src_mem_p = key + "@pool_src_mem_p";
+    const std::string key_pool_dst_mem_p = key + "@pool_dst_mem_p";
+    const std::string key_pool_workspace_memory =
+        key + "@pool_workspace_memory";

-    auto pool_prim = mkldnn::pooling_forward(*pool_pd, src_memory, dst_memory,
-                                             *workspace_memory);
+    auto pool_p =
+        std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
+    if (pool_p == nullptr) {
+      // TODO(pzelazko-intel): support more formats
+
+      auto src_md =
+          platform::MKLDNNMemDesc(src_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+      auto dst_md =
+          platform::MKLDNNMemDesc(dst_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+
+      std::shared_ptr<pooling_forward::primitive_desc> pool_pd =
+          CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
+                              pooling_type, mkldnn_engine);
+
+      // save pool_pd into global device context to be referred in backward path
+      dev_ctx.SetBlob(key_pool_pd, pool_pd);
+
+      std::shared_ptr<mkldnn::memory> workspace_memory =
+          CreateWorkspaceMemory(pool_pd, pooling_type, mkldnn_engine);
+
+      // save pool_workspace_memory to be referred in backward path
+      dev_ctx.SetBlob(key_pool_workspace_memory, workspace_memory);
+
+      auto pool_src_memory_p = std::make_shared<memory>(
+          memory::primitive_desc{src_md, mkldnn_engine},
+          static_cast<void*>(const_cast<T*>(input_data)));
+      dev_ctx.SetBlob(key_pool_src_mem_p, pool_src_memory_p);
+
+      auto pool_dst_memory_p = std::make_shared<memory>(
+          memory::primitive_desc{dst_md, mkldnn_engine},
+          static_cast<void*>(output_data));
+      dev_ctx.SetBlob(key_pool_dst_mem_p, pool_dst_memory_p);
+
+      pool_p = std::make_shared<pooling_forward>(
+          *pool_pd, *(pool_src_memory_p.get()), *(pool_dst_memory_p.get()),
+          *workspace_memory);
+      dev_ctx.SetBlob(key_pool_p, pool_p);
+    } else {
+      // Primitives already exist
+      auto pool_src_memory_p =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_src_mem_p));
+      PADDLE_ENFORCE(pool_src_memory_p != nullptr,
+                     "Fail to find pooling src mem_p in device context");
+      auto pool_dst_memory_p =
+          std::static_pointer_cast<memory>(dev_ctx.GetBlob(key_pool_dst_mem_p));
+      PADDLE_ENFORCE(pool_dst_memory_p != nullptr,
+                     "Fail to find pooling dst mem_p in device context");
+      pool_src_memory_p->set_data_handle(
+          reinterpret_cast<void*>(const_cast<T*>(input_data)));
+      pool_dst_memory_p->set_data_handle(output_data);
+    }

     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{pool_prim};
+    std::vector<mkldnn::primitive> pipeline{*(pool_p.get())};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }

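The forward hunk above implements a look-up-or-create cache: GetBlob the primitive by key; on a miss, build the memory descriptors, primitive descriptor, memories, and primitive, and SetBlob them all; on a hit, only rebind the raw data pointers with set_data_handle. A stripped-down, self-contained sketch of that pattern, with a hypothetical BlobStore standing in for Paddle's MKLDNNDeviceContext and a plain struct standing in for mkldnn::pooling_forward:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the device context's blob map
// (dev_ctx.SetBlob/GetBlob store shared_ptr<void> under a string key).
class BlobStore {
 public:
  void SetBlob(const std::string& key, std::shared_ptr<void> blob) {
    blobs_[key] = std::move(blob);
  }
  std::shared_ptr<void> GetBlob(const std::string& key) const {
    auto it = blobs_.find(key);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

struct Primitive {        // stand-in for mkldnn::pooling_forward
  void* data_handle = nullptr;
};

std::shared_ptr<Primitive> GetOrCreate(BlobStore& dev_ctx,
                                       const std::string& key, void* data) {
  auto p = std::static_pointer_cast<Primitive>(dev_ctx.GetBlob(key));
  if (p == nullptr) {
    p = std::make_shared<Primitive>();  // expensive construction, done once
    p->data_handle = data;
    dev_ctx.SetBlob(key, p);
  } else {
    p->data_handle = data;  // cheap rebind, analogous to set_data_handle()
  }
  return p;
}

int main() {
  BlobStore dev_ctx;
  float batch1[16] = {0}, batch2[16] = {0};
  auto p1 = GetOrCreate(dev_ctx, "2-3-4-4-max@pool_p", batch1);
  auto p2 = GetOrCreate(dev_ctx, "2-3-4-4-max@pool_p", batch2);
  std::cout << (p1 == p2) << "\n";  // prints 1: the primitive was reused
}

The saving comes from skipping descriptor and primitive construction on every iteration; between batches of identical shape only the source and destination pointers change.
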
@@ -120,9 +170,10 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     mkldnn::memory::primitive_desc workspace_md =
         pooling_type == "max"
             ? pool_pd->workspace_primitive_desc()
-            : mkldnn::memory::primitive_desc(
-                  {{}, mkldnn::memory::f32, mkldnn::memory::format::nchw},
-                  engine);
+            : mkldnn::memory::primitive_desc({{},
+                                              platform::MKLDNNGetDataType<T>(),
+                                              mkldnn::memory::format::nchw},
+                                             engine);

     auto p_workspace_memory = new mkldnn::memory(workspace_md);
     return std::unique_ptr<mkldnn::memory>(p_workspace_memory);
@@ -140,13 +191,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const Tensor* out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
     Tensor* in_x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));

-    // Get an unique name from "argument" name of "Out" variable
-    // This name will be used as key when referring info from device context
-    const std::string key = ctx.op().Input("Out");
-    const std::string key_pool_pd = key + "@pool_pd";
-    const std::string key_pool_workspace_memory =
-        key + "@pool_workspace_memory";
-
     std::string pooling_type = ctx.Attr<std::string>("pooling_type");
     std::vector<int> ksize = ctx.Attr<std::vector<int>>("ksize");
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
@@ -171,43 +215,76 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> diff_dst_tz =
         paddle::framework::vectorize2int(out_grad->dims());

-    auto diff_src_md = platform::MKLDNNMemDesc(diff_src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-    auto diff_dst_md = platform::MKLDNNMemDesc(diff_dst_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-
-    // Retrieve pool_pd/pool_workspace_memory from device context
-    auto pool_pd =
-        std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
-            dev_ctx.GetBlob(key_pool_pd));
-    PADDLE_ENFORCE(pool_pd != nullptr,
-                   "Fail to find pool_pd in device context");
-
-    auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
-        dev_ctx.GetBlob(key_pool_workspace_memory));
-    PADDLE_ENFORCE(workspace_memory != nullptr,
-                   "Fail to find workspace_memory in device context");
-
-    auto pool_bwd_desc = mkldnn::pooling_backward::desc(
-        pooling_type == "max" ? mkldnn::algorithm::pooling_max
-                              : mkldnn::algorithm::pooling_avg,
-        diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
-        mkldnn::padding_kind::zero);
-    auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
-        pool_bwd_desc, mkldnn_engine, *pool_pd);
-
-    auto diff_src_memory =
-        mkldnn::memory({diff_src_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(in_x_grad_data)));
-    auto diff_dst_memory =
-        mkldnn::memory({diff_dst_md, mkldnn_engine},
-                       static_cast<void*>(const_cast<T*>(out_grad_data)));
+    // Get a unique name from "argument" name of "Out" variable
+    // This name will be used as key when referring info from device context
+    const std::string key = gethash(diff_src_tz, pooling_type, ksize, strides,
+                                    paddings, ctx.op().Input("Out"));
+    const std::string key_pool_bwd_p = key + "@pool_bwd_p";
+    const std::string key_pool_diff_src_mem_p = key + "@pool_diff_src_mem_p";
+    const std::string key_pool_diff_dst_mem_p = key + "@pool_diff_dst_mem_p";
+    const std::string key_pool_pd = key + "@pool_pd";
+    const std::string key_pool_workspace_memory =
+        key + "@pool_workspace_memory";

-    auto bwd_prim = mkldnn::pooling_backward(
-        pool_bwd_pd, diff_dst_memory, *workspace_memory, diff_src_memory);
+    auto pool_bwd_p = std::static_pointer_cast<pooling_backward>(
+        dev_ctx.GetBlob(key_pool_bwd_p));
+    if (pool_bwd_p == nullptr) {
+      auto diff_src_md =
+          platform::MKLDNNMemDesc(diff_src_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+      auto diff_dst_md =
+          platform::MKLDNNMemDesc(diff_dst_tz, platform::MKLDNNGetDataType<T>(),
+                                  mkldnn::memory::format::nchw);
+      // Retrieve pool_pd/pool_workspace_memory from device context
+      auto pool_pd =
+          std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
+              dev_ctx.GetBlob(key_pool_pd));
+      PADDLE_ENFORCE(pool_pd != nullptr,
+                     "Fail to find pool_pd in device context");
+
+      auto workspace_memory = std::static_pointer_cast<mkldnn::memory>(
+          dev_ctx.GetBlob(key_pool_workspace_memory));
+      PADDLE_ENFORCE(workspace_memory != nullptr,
+                     "Fail to find workspace_memory in device context");
+
+      auto pool_diff_src_memory_p = std::make_shared<memory>(memory(
+          {diff_src_md, mkldnn_engine}, static_cast<void*>(in_x_grad_data)));
+      dev_ctx.SetBlob(key_pool_diff_src_mem_p, pool_diff_src_memory_p);
+
+      auto pool_diff_dst_memory_p = std::make_shared<memory>(
+          memory({diff_dst_md, mkldnn_engine},
+                 static_cast<void*>(const_cast<T*>(out_grad_data))));
+      dev_ctx.SetBlob(key_pool_diff_dst_mem_p, pool_diff_dst_memory_p);
+
+      auto pool_bwd_desc = mkldnn::pooling_backward::desc(
+          pooling_type == "max" ? mkldnn::algorithm::pooling_max
+                                : mkldnn::algorithm::pooling_avg,
+          diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
+          mkldnn::padding_kind::zero);
+      auto pool_bwd_pd = mkldnn::pooling_backward::primitive_desc(
+          pool_bwd_desc, mkldnn_engine, *pool_pd);
+
+      pool_bwd_p = std::make_shared<pooling_backward>(
+          pool_bwd_pd, *(pool_diff_dst_memory_p.get()), *workspace_memory,
+          *(pool_diff_src_memory_p));
+      dev_ctx.SetBlob(key_pool_bwd_p, pool_bwd_p);
+    } else {
+      // Primitives already exist
+      auto pool_diff_src_memory_p = std::static_pointer_cast<memory>(
+          dev_ctx.GetBlob(key_pool_diff_src_mem_p));
+      PADDLE_ENFORCE(pool_diff_src_memory_p != nullptr,
+                     "Fail to find pooling src mem_p in device context");
+      auto pool_diff_dst_memory_p = std::static_pointer_cast<memory>(
+          dev_ctx.GetBlob(key_pool_diff_dst_mem_p));
+      PADDLE_ENFORCE(pool_diff_dst_memory_p != nullptr,
+                     "Fail to find pooling dst mem_p in device context");
+      pool_diff_src_memory_p->set_data_handle(
+          reinterpret_cast<void*>(in_x_grad_data));
+      pool_diff_dst_memory_p->set_data_handle(const_cast<T*>(out_grad_data));
+    }

     // push primitive to stream and wait until it's executed
-    std::vector<mkldnn::primitive> pipeline{bwd_prim};
+    std::vector<mkldnn::primitive> pipeline{*(pool_bwd_p.get())};
     mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
   }  // Compute()
 };
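
The backward kernel mirrors the same cache pattern and, in addition, must recover the forward pass's pool_pd and workspace memory from the device context. That lookup works only because both kernels compute the same base key: the forward hashes src_tz while the backward hashes diff_src_tz, and those are equal since the input gradient has the input's shape; both also use the same "Out" variable name as the suffix. A small sketch of this cross-kernel handshake (all types here are hypothetical stand-ins):

#include <cassert>
#include <map>
#include <memory>
#include <string>

// Hypothetical global blob map standing in for the device context.
static std::map<std::string, std::shared_ptr<void>> g_blobs;

struct PrimitiveDesc {};  // stand-in for pooling_forward::primitive_desc

void ForwardPass(const std::string& key) {
  // Forward stores its primitive_desc under key + "@pool_pd".
  g_blobs[key + "@pool_pd"] = std::make_shared<PrimitiveDesc>();
}

void BackwardPass(const std::string& key) {
  // Backward rebuilds the identical key from the same dims/attrs
  // and fetches what the forward pass stored.
  auto pd = std::static_pointer_cast<PrimitiveDesc>(g_blobs[key + "@pool_pd"]);
  assert(pd != nullptr && "Fail to find pool_pd in device context");
}

int main() {
  // Same dims, attrs, and "Out" variable name on both sides -> same key.
  const std::string key = "2-3-4-4-2-2-2-2-0-0-maxpool_out";
  ForwardPass(key);
  BackwardPass(key);  // succeeds: both sides computed the same base key
}
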

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 10 additions & 0 deletions
@@ -71,5 +71,15 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
   return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
 }

+template <typename Type>
+mkldnn::memory::data_type MKLDNNGetDataType() {
+  return mkldnn::memory::data_undef;
+}
+
+template <>
+inline mkldnn::memory::data_type MKLDNNGetDataType<float>() {
+  return mkldnn::memory::f32;
+}
+
 }  // namespace platform
 }  // namespace paddle
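
The new MKLDNNGetDataType<T>() helper maps a C++ element type onto the MKL-DNN data-type enum via template specialization: the primary template falls back to data_undef, and only float is specialized to f32 in this commit. A sketch of the pattern in isolation, with a hypothetical local enum in place of mkldnn::memory::data_type:

#include <iostream>

enum class data_type { undef, f32 };  // hypothetical stand-in enum

// Primary template: any type not specialized below maps to undef.
template <typename Type>
data_type GetDataType() {
  return data_type::undef;
}

// Full specialization: float maps to f32, mirroring the helper in the diff.
template <>
data_type GetDataType<float>() {
  return data_type::f32;
}

int main() {
  std::cout << (GetDataType<float>() == data_type::f32) << "\n";     // 1
  std::cout << (GetDataType<double>() == data_type::undef) << "\n";  // 1
}

An unsupported element type thus still compiles but yields data_undef, which would surface as a runtime failure inside MKL-DNN rather than a compile error; supporting another type later takes only one more specialization.
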
