@@ -23,6 +23,13 @@ using paddle::framework::Tensor;
23
23
using paddle::platform::MKLDNNDeviceContext;
24
24
25
25
namespace {
26
+ std::string gethash (const mkldnn::memory::dims &operand_dims,
27
+ const mkldnn::algorithm algorithm) {
28
+ return std::string (std::to_string (operand_dims[0 ]) + " -" +
29
+ std::to_string (operand_dims[1 ]) + " -" +
30
+ std::to_string (algorithm));
31
+ }
32
+
26
33
template <typename T, typename ExecContext>
27
34
void eltwise_forward (const ExecContext &ctx, mkldnn::algorithm algorithm,
28
35
const T alpha = 0 , const T beta = 0 ) {
@@ -44,37 +51,53 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
44
51
" Input dim must be with 2 or 4" );
45
52
std::vector<int > src_tz = framework::vectorize2int (src->dims ());
46
53
47
- // create memory description
48
- auto data_md = src_tz.size () == 2
49
- ? platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
50
- mkldnn::memory::format::nc)
51
- : platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
52
- mkldnn::memory::format::nchw);
53
-
54
- // create memory primitives
55
- auto src_memory = std::make_shared<mkldnn::memory>(
56
- mkldnn::memory ({data_md, mkldnn_engine},
57
- static_cast <void *>(const_cast <float *>(src_data))));
58
- // save source memory to device context to be referred in backward path
59
- dev_ctx.SetBlob (" InputX@eltwise_pd" , src_memory);
60
- auto dst_memory =
61
- mkldnn::memory ({data_md, mkldnn_engine},
62
- static_cast <void *>(const_cast <float *>(dst_data)));
63
-
64
- auto forward_desc = mkldnn::eltwise_forward::desc (
65
- mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
66
-
67
- // save prim desc into global device context to be referred in backward path
68
- const std::string key = ctx.op ().Output (" Out" );
69
- const std::string key_eltwise_pd = key + " @eltwise_pd" ;
70
- auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
71
- forward_desc, mkldnn_engine);
72
- dev_ctx.SetBlob (key_eltwise_pd, forward_pd);
73
-
74
- auto eltwise = mkldnn::eltwise_forward (*forward_pd, *src_memory, dst_memory);
54
+ const std::string key = gethash (src_tz, algorithm);
55
+ const std::string key_src_mem = key + " @eltwise_src_mem" ;
56
+ const std::string key_dst_mem = key + " @eltwise_dst_mem" ;
57
+ const std::string key_fwd = key + " @eltwise_fwd" ;
58
+
59
+ std::shared_ptr<void > p_src_mem = dev_ctx.GetBlob (key_src_mem);
60
+ std::shared_ptr<void > p_dst_mem = dev_ctx.GetBlob (key_dst_mem);
61
+ std::shared_ptr<void > p_fwd = dev_ctx.GetBlob (key_fwd);
62
+
63
+ if (p_src_mem == nullptr || p_dst_mem == nullptr || p_fwd == nullptr ) {
64
+ // create memory description
65
+ auto data_md = src_tz.size () == 2
66
+ ? platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
67
+ mkldnn::memory::format::nc)
68
+ : platform::MKLDNNMemDesc (src_tz, mkldnn::memory::f32 ,
69
+ mkldnn::memory::format::nchw);
70
+
71
+ // create memory primitives
72
+ p_src_mem = std::make_shared<mkldnn::memory>(
73
+ mkldnn::memory ({data_md, mkldnn_engine},
74
+ static_cast <void *>(const_cast <float *>(src_data))));
75
+ dev_ctx.SetBlob (key_src_mem, p_src_mem);
76
+
77
+ p_dst_mem = std::make_shared<mkldnn::memory>(
78
+ mkldnn::memory ({data_md, mkldnn_engine},
79
+ static_cast <void *>(const_cast <float *>(dst_data))));
80
+ dev_ctx.SetBlob (key_dst_mem, p_dst_mem);
81
+
82
+ auto fwd_desc = mkldnn::eltwise_forward::desc (
83
+ mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
84
+ auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
85
+ fwd_desc, mkldnn_engine);
86
+ p_fwd = std::make_shared<mkldnn::eltwise_forward>(
87
+ *(p_fwd_pd.get ()), *(static_cast <mkldnn::memory *>(p_src_mem.get ())),
88
+ *(static_cast <mkldnn::memory *>(p_dst_mem.get ())));
89
+ dev_ctx.SetBlob (key_fwd, p_fwd);
90
+ } else {
91
+ std::static_pointer_cast<mkldnn::memory>(p_src_mem)->set_data_handle (
92
+ reinterpret_cast <void *>(const_cast <T *>(src_data)));
93
+
94
+ std::static_pointer_cast<mkldnn::memory>(p_dst_mem)->set_data_handle (
95
+ reinterpret_cast <void *>(const_cast <T *>(dst_data)));
96
+ }
75
97
76
98
// push primitive to stream and wait until it's executed
77
- std::vector<mkldnn::primitive> pipeline = {eltwise};
99
+ std::vector<mkldnn::primitive> pipeline = {
100
+ *(static_cast <mkldnn::eltwise_forward::primitive *>(p_fwd.get ()))};
78
101
mkldnn::stream (mkldnn::stream::kind::eager).submit (pipeline).wait ();
79
102
}
80
103
@@ -85,7 +108,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
85
108
const auto &mkldnn_engine = dev_ctx.GetEngine ();
86
109
87
110
// get buffers
88
- const auto *x = ctx.template Input <Tensor>(" Out" );
111
+ const auto *out = ctx.template Input <Tensor>(" Out" );
89
112
90
113
auto *dout = ctx.template Input <Tensor>(framework::GradVarName (" Out" ));
91
114
const auto *diff_dst = dout->template data <T>();
@@ -95,7 +118,12 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
95
118
const T *diff_src = dx->template mutable_data <T>(ctx.GetPlace ());
96
119
97
120
// get memory dim
98
- std::vector<int > src_tz = framework::vectorize2int (x->dims ());
121
+ std::vector<int > src_tz = framework::vectorize2int (out->dims ());
122
+
123
+ const std::string key = gethash (src_tz, algorithm);
124
+ const std::string key_src_mem = key + " @eltwise_src_mem" ;
125
+ const std::string key_dst_mem = key + " @eltwise_dst_mem" ;
126
+ const std::string key_fwd = key + " @eltwise_fwd" ;
99
127
100
128
// create memory description
101
129
auto data_md = src_tz.size () == 2
@@ -105,8 +133,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
105
133
mkldnn::memory::format::nchw);
106
134
107
135
// retrieve source memory from device context
108
- const std::shared_ptr<void > src_memory = dev_ctx.GetBlob (" InputX@eltwise_pd " );
109
- auto *p_src_memory = static_cast <mkldnn::memory *>(src_memory .get ());
136
+ const std::shared_ptr<void > src_mem = dev_ctx.GetBlob (key_src_mem );
137
+ auto *p_src_mem = static_cast <mkldnn::memory *>(src_mem .get ());
110
138
111
139
// create memory primitives
112
140
auto diff_src_memory =
@@ -120,9 +148,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
120
148
mkldnn::eltwise_backward::desc (algorithm, data_md, data_md, alpha, beta);
121
149
122
150
// retrieve eltwise primitive desc from device context
123
- const std::string key = ctx.op ().Input (" Out" );
124
- const std::string key_eltwise_pd = key + " @eltwise_pd" ;
125
- const std::shared_ptr<void > forward_pd = dev_ctx.GetBlob (key_eltwise_pd);
151
+ const std::shared_ptr<void > forward_pd = dev_ctx.GetBlob (key_fwd);
126
152
PADDLE_ENFORCE (forward_pd != nullptr ,
127
153
" Fail to find eltwise_pd in device context" );
128
154
auto *p_forward_pd =
@@ -131,8 +157,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
131
157
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc (
132
158
backward_desc, mkldnn_engine, *p_forward_pd);
133
159
134
- auto eltwise_bwd = mkldnn::eltwise_backward (
135
- eltwise_bwd_prim_desc, *p_src_memory, diff_dst_memory, diff_src_memory);
160
+ auto eltwise_bwd = mkldnn::eltwise_backward (eltwise_bwd_prim_desc, *p_src_mem,
161
+ diff_dst_memory, diff_src_memory);
136
162
137
163
// push primitive to stream and wait until it's executed
138
164
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
0 commit comments