@@ -58,13 +58,18 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
58
58
std::vector<int > src_tz = framework::vectorize2int (src->dims ());
59
59
60
60
const std::string key = gethash (src_tz, algorithm);
61
+ const std::string key_src_data = key + " @eltwise_fwd_src_data" ;
61
62
const std::string key_src_mem = key + " @eltwise_fwd_src_mem" ;
62
63
const std::string key_dst_mem = key + " @eltwise_fwd_dst_mem" ;
63
64
const std::string key_fwd = key + " @eltwise_fwd" ;
64
65
65
66
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
66
67
dev_ctx.GetBlob (key_fwd));
67
68
69
+ // Save the input data pointer so that the backward pass can refer to it
70
+ auto p_src_data = std::make_shared<const T *>(src_data);
71
+ dev_ctx.SetBlob (key_src_data, p_src_data);
72
+
68
73
if (p_fwd == nullptr ) {
69
74
// create memory description
70
75
auto data_md = src_tz.size () == 2
@@ -131,11 +136,19 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
131
136
std::vector<int > src_tz = framework::vectorize2int (out->dims ());
132
137
133
138
const std::string key = gethash (src_tz, algorithm);
134
-
135
139
const std::string key_diff_src_mem = key + " @eltwise_diff_src_mem" ;
136
140
const std::string key_diff_dst_mem = key + " @eltwise_diff_dst_mem" ;
137
141
const std::string key_grad = key + " @eltwise_grad" ;
138
142
143
+ const std::string key_src_data = key + " @eltwise_fwd_src_data" ;
144
+ const auto p_src_data =
145
+ std::static_pointer_cast<T *>(dev_ctx.GetBlob (key_src_data));
146
+
147
+ const std::string key_src_mem = key + " @eltwise_fwd_src_mem" ;
148
+ auto p_src_mem =
149
+ std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob (key_src_mem));
150
+ p_src_mem->set_data_handle (*p_src_data.get ());
151
+
139
152
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
140
153
dev_ctx.GetBlob (key_grad));
141
154
@@ -167,9 +180,6 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
167
180
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc (
168
181
bwd_desc, mkldnn_engine, *p_fwd_pd);
169
182
170
- const std::string key_src_mem = key + " @eltwise_fwd_src_mem" ;
171
- const std::shared_ptr<void > p_src_mem = dev_ctx.GetBlob (key_src_mem);
172
-
173
183
p_grad = std::make_shared<mkldnn::eltwise_backward>(
174
184
eltwise_bwd_prim_desc, *static_cast <mkldnn::memory *>(p_src_mem.get ()),
175
185
*(static_cast <mkldnn::memory *>(p_diff_dst_mem.get ())),
0 commit comments