 #include "mkldnn.hpp"
 #include "paddle/fluid/operators/activation_op.h"
 #include "paddle/fluid/operators/mkldnn_activation_op.h"
+#include "paddle/fluid/platform/mkldnn_helper.h"

 namespace paddle {
 namespace operators {
@@ -25,9 +26,14 @@ using paddle::platform::MKLDNNDeviceContext;
 namespace {
 std::string gethash(const mkldnn::memory::dims &operand_dims,
                     const mkldnn::algorithm algorithm) {
-  return std::string(std::to_string(operand_dims[0]) + "-" +
-                     std::to_string(operand_dims[1]) + "-" +
-                     std::to_string(algorithm));
+  auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
+    std::string dstr = "";
+    for (size_t i = 0; i < operand_dims.size(); ++i) {
+      dstr += std::to_string(operand_dims[i]) + "-";
+    }
+    return dstr;
+  };
+  return dim2str(operand_dims) + std::to_string(algorithm);
 }

 template <typename T, typename ExecContext>
@@ -44,23 +50,22 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto *src_data = src->template data<T>();

   auto *dst = ctx.template Output<Tensor>("Out");
-  const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
+  T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());

   // get memory dim
   PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
                  "Input dim must be with 2 or 4");
   std::vector<int> src_tz = framework::vectorize2int(src->dims());

   const std::string key = gethash(src_tz, algorithm);
-  const std::string key_src_mem = key + "@eltwise_src_mem";
-  const std::string key_dst_mem = key + "@eltwise_dst_mem";
+  const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+  const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
   const std::string key_fwd = key + "@eltwise_fwd";

-  std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
-  std::shared_ptr<void> p_dst_mem = dev_ctx.GetBlob(key_dst_mem);
-  std::shared_ptr<void> p_fwd = dev_ctx.GetBlob(key_fwd);
+  auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
+      dev_ctx.GetBlob(key_fwd));

-  if (p_src_mem == nullptr || p_dst_mem == nullptr || p_fwd == nullptr) {
+  if (p_fwd == nullptr) {
     // create memory description
     auto data_md = src_tz.size() == 2
                        ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
@@ -69,35 +74,40 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                                                  mkldnn::memory::format::nchw);

     // create memory primitives
-    p_src_mem = std::make_shared<mkldnn::memory>(
-        mkldnn::memory({data_md, mkldnn_engine},
-                       static_cast<void *>(const_cast<float *>(src_data))));
+    auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
+        {data_md, mkldnn_engine}, platform::to_void_cast(src_data)));
     dev_ctx.SetBlob(key_src_mem, p_src_mem);

-    p_dst_mem = std::make_shared<mkldnn::memory>(
-        mkldnn::memory({data_md, mkldnn_engine},
-                       static_cast<void *>(const_cast<float *>(dst_data))));
+    auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
+        {data_md, mkldnn_engine}, platform::to_void_cast(dst_data)));
     dev_ctx.SetBlob(key_dst_mem, p_dst_mem);

     auto fwd_desc = mkldnn::eltwise_forward::desc(
         mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
     auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
         fwd_desc, mkldnn_engine);
+    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
+    dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
     p_fwd = std::make_shared<mkldnn::eltwise_forward>(
-        *(p_fwd_pd.get()), *(static_cast<mkldnn::memory *>(p_src_mem.get())),
-        *(static_cast<mkldnn::memory *>(p_dst_mem.get())));
+        *p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
     dev_ctx.SetBlob(key_fwd, p_fwd);
   } else {
-    std::static_pointer_cast<mkldnn::memory>(p_src_mem)->set_data_handle(
-        reinterpret_cast<void *>(const_cast<T *>(src_data)));
-
-    std::static_pointer_cast<mkldnn::memory>(p_dst_mem)->set_data_handle(
-        reinterpret_cast<void *>(const_cast<T *>(dst_data)));
+    // primitives already exist
+    auto p_src_mem =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
+    PADDLE_ENFORCE(p_src_mem != nullptr,
+                   "Fail to find eltwise p_src_mem in device context.");
+    auto p_dst_mem =
+        std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
+    PADDLE_ENFORCE(p_dst_mem != nullptr,
+                   "Fail to find eltwise p_src_mem in device context.");
+
+    p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data));
+    p_dst_mem->set_data_handle(dst_data);
   }

   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {
-      *(static_cast<mkldnn::eltwise_forward::primitive *>(p_fwd.get()))};
+  std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())};
   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 }

@@ -121,47 +131,64 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   std::vector<int> src_tz = framework::vectorize2int(out->dims());

   const std::string key = gethash(src_tz, algorithm);
-  const std::string key_src_mem = key + "@eltwise_src_mem";
-  const std::string key_dst_mem = key + "@eltwise_dst_mem";
-  const std::string key_fwd = key + "@eltwise_fwd";

-  // create memory description
-  auto data_md = src_tz.size() == 2
-                     ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nc)
-                     : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
-                                               mkldnn::memory::format::nchw);
-
-  // retrieve source memory from device context
-  const std::shared_ptr<void> src_mem = dev_ctx.GetBlob(key_src_mem);
-  auto *p_src_mem = static_cast<mkldnn::memory *>(src_mem.get());
-
-  // create memory primitives
-  auto diff_src_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(diff_src)));
-  auto diff_dst_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(diff_dst)));
-
-  auto backward_desc =
-      mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta);
-
-  // retrieve eltwise primitive desc from device context
-  const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_fwd);
-  PADDLE_ENFORCE(forward_pd != nullptr,
-                 "Fail to find eltwise_pd in device context");
-  auto *p_forward_pd =
-      static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get());
-
-  auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
-      backward_desc, mkldnn_engine, *p_forward_pd);
-
-  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, *p_src_mem,
-                                              diff_dst_memory, diff_src_memory);
+  const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
+  const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
+  const std::string key_grad = key + "@eltwise_grad";
+
+  auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
+      dev_ctx.GetBlob(key_grad));
+
+  if (p_grad == nullptr) {
+    // create memory description
+    auto data_md = src_tz.size() == 2
+                       ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nc)
+                       : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
+                                                 mkldnn::memory::format::nchw);
+
+    // create memory primitives
+    std::shared_ptr<void> p_diff_src_mem =
+        std::make_shared<mkldnn::memory>(mkldnn::memory(
+            {data_md, mkldnn_engine}, platform::to_void_cast(diff_src)));
+    dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem);
+    std::shared_ptr<void> p_diff_dst_mem =
+        std::make_shared<mkldnn::memory>(mkldnn::memory(
+            {data_md, mkldnn_engine}, platform::to_void_cast(diff_dst)));
+    dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem);
+
+    auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md,
+                                                   alpha, beta);
+
+    const std::string key_fwd_pd = key + "eltwise_fwd_pd";
+    auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>(
+        dev_ctx.GetBlob(key_fwd_pd).get());
+
+    auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
+        bwd_desc, mkldnn_engine, *p_fwd_pd);
+
+    const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
+    const std::shared_ptr<void> p_src_mem = dev_ctx.GetBlob(key_src_mem);
+
+    p_grad = std::make_shared<mkldnn::eltwise_backward>(
+        eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
+        *(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
+        *(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
+  } else {
+    // primitives already exist
+    auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx.GetBlob(key_diff_src_mem));
+    auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
+        dev_ctx.GetBlob(key_diff_dst_mem));
+
+    p_diff_src_mem->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_src));
+    p_diff_dst_mem->set_data_handle(
+        platform::to_void_reinterpret_cast(diff_dst));
+  }

   // push primitive to stream and wait until it's executed
-  std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
+  std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())};
   mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
 }
 }  // anonymous namespace
0 commit comments