Skip to content

Commit 0aa0192

Browse files
committed
Add backward
1 parent 0cc25a4 commit 0aa0192

File tree

2 files changed

+95
-63
lines changed

2 files changed

+95
-63
lines changed

paddle/fluid/operators/activation_mkldnn_op.cc

Lines changed: 90 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mkldnn.hpp"
1616
#include "paddle/fluid/operators/activation_op.h"
1717
#include "paddle/fluid/operators/mkldnn_activation_op.h"
18+
#include "paddle/fluid/platform/mkldnn_helper.h"
1819

1920
namespace paddle {
2021
namespace operators {
@@ -25,9 +26,14 @@ using paddle::platform::MKLDNNDeviceContext;
2526
namespace {
2627
// Build a cache key of the form "d0-d1-...-dn-<algorithm>" so MKLDNN
// primitives/memories can be memoized in the device context per input
// shape and activation algorithm.
std::string gethash(const mkldnn::memory::dims &operand_dims,
                    const mkldnn::algorithm algorithm) {
  std::string key;
  // Rough upper bound to avoid repeated reallocations while appending:
  // each dimension plus the '-' separator is typically well under 12 chars.
  key.reserve(operand_dims.size() * 12 + 4);
  // Range-for instead of an index loop; also avoids the previous lambda
  // whose parameter shadowed the enclosing `operand_dims`.
  for (const auto dim : operand_dims) {
    key += std::to_string(dim) + "-";
  }
  // mkldnn::algorithm is an unscoped enum, so it converts to int here.
  key += std::to_string(algorithm);
  return key;
}
3238

3339
template <typename T, typename ExecContext>
@@ -44,23 +50,22 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
4450
const auto *src_data = src->template data<T>();
4551

4652
auto *dst = ctx.template Output<Tensor>("Out");
47-
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
53+
T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
4854

4955
// get memory dim
5056
PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
5157
"Input dim must be with 2 or 4");
5258
std::vector<int> src_tz = framework::vectorize2int(src->dims());
5359

5460
const std::string key = gethash(src_tz, algorithm);
55-
const std::string key_src_mem = key + "@eltwise_src_mem";
56-
const std::string key_dst_mem = key + "@eltwise_dst_mem";
61+
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
62+
const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
5763
const std::string key_fwd = key + "@eltwise_fwd";
5864

59-
std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
60-
std::shared_ptr<void> p_dst_mem = dev_ctx.GetBlob(key_dst_mem);
61-
std::shared_ptr<void> p_fwd = dev_ctx.GetBlob(key_fwd);
65+
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
66+
dev_ctx.GetBlob(key_fwd));
6267

63-
if (p_src_mem == nullptr || p_dst_mem == nullptr || p_fwd == nullptr) {
68+
if (p_fwd == nullptr) {
6469
// create memory description
6570
auto data_md = src_tz.size() == 2
6671
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
@@ -69,35 +74,40 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
6974
mkldnn::memory::format::nchw);
7075

7176
// create memory primitives
72-
p_src_mem = std::make_shared<mkldnn::memory>(
73-
mkldnn::memory({data_md, mkldnn_engine},
74-
static_cast<void *>(const_cast<float *>(src_data))));
77+
auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
78+
{data_md, mkldnn_engine}, platform::to_void_cast(src_data)));
7579
dev_ctx.SetBlob(key_src_mem, p_src_mem);
7680

77-
p_dst_mem = std::make_shared<mkldnn::memory>(
78-
mkldnn::memory({data_md, mkldnn_engine},
79-
static_cast<void *>(const_cast<float *>(dst_data))));
81+
auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
82+
{data_md, mkldnn_engine}, platform::to_void_cast(dst_data)));
8083
dev_ctx.SetBlob(key_dst_mem, p_dst_mem);
8184

8285
auto fwd_desc = mkldnn::eltwise_forward::desc(
8386
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
8487
auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
8588
fwd_desc, mkldnn_engine);
89+
const std::string key_fwd_pd = key + "eltwise_fwd_pd";
90+
dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
8691
p_fwd = std::make_shared<mkldnn::eltwise_forward>(
87-
*(p_fwd_pd.get()), *(static_cast<mkldnn::memory *>(p_src_mem.get())),
88-
*(static_cast<mkldnn::memory *>(p_dst_mem.get())));
92+
*p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
8993
dev_ctx.SetBlob(key_fwd, p_fwd);
9094
} else {
91-
std::static_pointer_cast<mkldnn::memory>(p_src_mem)->set_data_handle(
92-
reinterpret_cast<void *>(const_cast<T *>(src_data)));
93-
94-
std::static_pointer_cast<mkldnn::memory>(p_dst_mem)->set_data_handle(
95-
reinterpret_cast<void *>(const_cast<T *>(dst_data)));
95+
// primitives already exist
96+
auto p_src_mem =
97+
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
98+
PADDLE_ENFORCE(p_src_mem != nullptr,
99+
"Fail to find eltwise p_src_mem in device context.");
100+
auto p_dst_mem =
101+
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
102+
PADDLE_ENFORCE(p_dst_mem != nullptr,
103+
"Fail to find eltwise p_dst_mem in device context.");
104+
105+
p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data));
106+
p_dst_mem->set_data_handle(dst_data);
96107
}
97108

98109
// push primitive to stream and wait until it's executed
99-
std::vector<mkldnn::primitive> pipeline = {
100-
*(static_cast<mkldnn::eltwise_forward::primitive *>(p_fwd.get()))};
110+
std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())};
101111
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
102112
}
103113

@@ -121,47 +131,64 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
121131
std::vector<int> src_tz = framework::vectorize2int(out->dims());
122132

123133
const std::string key = gethash(src_tz, algorithm);
124-
const std::string key_src_mem = key + "@eltwise_src_mem";
125-
const std::string key_dst_mem = key + "@eltwise_dst_mem";
126-
const std::string key_fwd = key + "@eltwise_fwd";
127134

128-
// create memory description
129-
auto data_md = src_tz.size() == 2
130-
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
131-
mkldnn::memory::format::nc)
132-
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
133-
mkldnn::memory::format::nchw);
134-
135-
// retrieve source memory from device context
136-
const std::shared_ptr<void> src_mem = dev_ctx.GetBlob(key_src_mem);
137-
auto *p_src_mem = static_cast<mkldnn::memory *>(src_mem.get());
138-
139-
// create memory primitives
140-
auto diff_src_memory =
141-
mkldnn::memory({data_md, mkldnn_engine},
142-
static_cast<void *>(const_cast<float *>(diff_src)));
143-
auto diff_dst_memory =
144-
mkldnn::memory({data_md, mkldnn_engine},
145-
static_cast<void *>(const_cast<float *>(diff_dst)));
146-
147-
auto backward_desc =
148-
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
149-
150-
// retrieve eltwise primitive desc from device context
151-
const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_fwd);
152-
PADDLE_ENFORCE(forward_pd != nullptr,
153-
"Fail to find eltwise_pd in device context");
154-
auto *p_forward_pd =
155-
static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
156-
157-
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
158-
backward_desc, mkldnn_engine, *p_forward_pd);
159-
160-
auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, *p_src_mem,
161-
diff_dst_memory, diff_src_memory);
135+
const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
136+
const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
137+
const std::string key_grad = key + "@eltwise_grad";
138+
139+
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
140+
dev_ctx.GetBlob(key_grad));
141+
142+
if (p_grad == nullptr) {
143+
// create memory description
144+
auto data_md = src_tz.size() == 2
145+
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
146+
mkldnn::memory::format::nc)
147+
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
148+
mkldnn::memory::format::nchw);
149+
150+
// create memory primitives
151+
std::shared_ptr<void> p_diff_src_mem =
152+
std::make_shared<mkldnn::memory>(mkldnn::memory(
153+
{data_md, mkldnn_engine}, platform::to_void_cast(diff_src)));
154+
dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem);
155+
std::shared_ptr<void> p_diff_dst_mem =
156+
std::make_shared<mkldnn::memory>(mkldnn::memory(
157+
{data_md, mkldnn_engine}, platform::to_void_cast(diff_dst)));
158+
dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem);
159+
160+
auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md,
161+
alpha, beta);
162+
163+
const std::string key_fwd_pd = key + "eltwise_fwd_pd";
164+
auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>(
165+
dev_ctx.GetBlob(key_fwd_pd).get());
166+
167+
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
168+
bwd_desc, mkldnn_engine, *p_fwd_pd);
169+
170+
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
171+
const std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
172+
173+
p_grad = std::make_shared<mkldnn::eltwise_backward>(
174+
eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
175+
*(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
176+
*(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
177+
} else {
178+
// primitives already exist
179+
auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
180+
dev_ctx.GetBlob(key_diff_src_mem));
181+
auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
182+
dev_ctx.GetBlob(key_diff_dst_mem));
183+
184+
p_diff_src_mem->set_data_handle(
185+
platform::to_void_reinterpret_cast(diff_src));
186+
p_diff_dst_mem->set_data_handle(
187+
platform::to_void_reinterpret_cast(diff_dst));
188+
}
162189

163190
// push primitive to stream and wait until it's executed
164-
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
191+
std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())};
165192
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
166193
}
167194
} // anonymous namespace

paddle/fluid/platform/mkldnn_helper.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
3838
return static_cast<void*>(const_cast<Type*>(t));
3939
}
4040

41+
template <typename Type>
42+
void* to_void_reinterpret_cast(const Type* t) {
43+
return reinterpret_cast<void*>(const_cast<Type*>(t));
44+
}
45+
4146
template <class Type>
4247
using tf_desc = typename Type::desc;
4348

0 commit comments

Comments (0)