Skip to content

Commit 32929cd

Browse files
committed
Cache input data
1 parent 0aa0192 commit 32929cd

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

paddle/fluid/operators/activation_mkldnn_op.cc

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,13 +58,18 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
5858
std::vector<int> src_tz = framework::vectorize2int(src->dims());
5959

6060
const std::string key = gethash(src_tz, algorithm);
61+
const std::string key_src_data = key + "@eltwise_fwd_src_data";
6162
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
6263
const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
6364
const std::string key_fwd = key + "@eltwise_fwd";
6465

6566
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
6667
dev_ctx.GetBlob(key_fwd));
6768

69+
// save the input data pointer so that the backward path can refer to it
70+
auto p_src_data = std::make_shared<const T *>(src_data);
71+
dev_ctx.SetBlob(key_src_data, p_src_data);
72+
6873
if (p_fwd == nullptr) {
6974
// create memory description
7075
auto data_md = src_tz.size() == 2
@@ -131,11 +136,19 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
131136
std::vector<int> src_tz = framework::vectorize2int(out->dims());
132137

133138
const std::string key = gethash(src_tz, algorithm);
134-
135139
const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
136140
const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
137141
const std::string key_grad = key + "@eltwise_grad";
138142

143+
const std::string key_src_data = key + "@eltwise_fwd_src_data";
144+
const auto p_src_data =
145+
std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
146+
147+
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
148+
auto p_src_mem =
149+
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
150+
p_src_mem->set_data_handle(*p_src_data.get());
151+
139152
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
140153
dev_ctx.GetBlob(key_grad));
141154

@@ -167,9 +180,6 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
167180
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
168181
bwd_desc, mkldnn_engine, *p_fwd_pd);
169182

170-
const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
171-
const std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
172-
173183
p_grad = std::make_shared<mkldnn::eltwise_backward>(
174184
eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
175185
*(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),

0 commit comments

Comments
 (0)