@@ -27,8 +27,81 @@ using paddle::platform::MKLDNNMemDesc;
 using mkldnn::memory;  // Note: paddle has also "memory" namespace
 using mkldnn::primitive;
 using mkldnn::softmax_forward;
+using mkldnn::softmax_backward;
 using mkldnn::prop_kind;
 using mkldnn::stream;
+using platform::to_void_cast;
+
+class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
+ public:
+  SoftmaxMKLDNNHandler(
+      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
+      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+      const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        softmax_pd_(softmax_pd) {}
+
+  SoftmaxMKLDNNHandler(
+      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
+      std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd,
+      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
+      const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
+        softmax_pd_(softmax_pd),
+        softmax_bwd_pd_(softmax_bwd_pd) {
+    // If we are in Grad operator then update a key with BWD suffix to
+    // distinguish from FWD memory primitives
+    key_ += "-BWD";
+  }
+
+  std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> src_memory_p) {
+    /* Generate key */
+    auto prim_key = key_ + "@softmax_p";
+
+    auto softmax_p = std::static_pointer_cast<mkldnn::softmax_forward>(
+        dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE((softmax_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find softmax primitive in device context");
+    if (softmax_p == nullptr) {
+      softmax_p = std::make_shared<mkldnn::softmax_forward>(
+          *(softmax_pd_.get()),
+          *(static_cast<mkldnn::memory*>(src_memory_p.get())),
+          *(static_cast<mkldnn::memory*>(dst_memory_p.get())));
+      dev_ctx_.SetBlob(prim_key, softmax_p);
+    } else {
+      is_reusing_ = true;
+    }
+
+    return softmax_p;
+  }
+
+  std::shared_ptr<mkldnn::softmax_backward> AcquireSoftmaxBackward(
+      std::shared_ptr<mkldnn::memory> dst_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_dst_memory_p,
+      std::shared_ptr<mkldnn::memory> diff_src_memory_p) {
+    auto prim_key = key_ + "@softmax_bwd_p";
+    auto softmax_bwd_p = std::static_pointer_cast<mkldnn::softmax_backward>(
+        dev_ctx_.GetBlob(prim_key));
+    PADDLE_ENFORCE((softmax_bwd_p != nullptr) || (is_reusing_ == false),
+                   "Fail to find softmax backward primitive in device context");
+    if (softmax_bwd_p == nullptr) {
+      softmax_bwd_p = std::make_shared<mkldnn::softmax_backward>(
+          *softmax_bwd_pd_, *(dst_memory_p.get()), *(diff_dst_memory_p.get()),
+          *(diff_src_memory_p.get()));
+      dev_ctx_.SetBlob(prim_key, softmax_bwd_p);
+    } else {
+      is_reusing_ = true;
+    }
+
+    return softmax_bwd_p;
+  }
+
+ private:
+  std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd_;
+  std::shared_ptr<mkldnn::softmax_backward::primitive_desc> softmax_bwd_pd_;
+};

 template <typename T>
 class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
@@ -54,56 +127,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
     // Generate keys for storing/retrieving primitives for this operator
-    // TODO(jczaja): Each MKLDNN operator may have different hashing function
-    auto gethash = [](memory::dims& operand_dims) {
-      return std::string(std::to_string(operand_dims[0]) + "-" +
-                         std::to_string(operand_dims[1]));
-    };
-    const std::string key = gethash(softmax_tz);
-    const std::string key_softmax_p = key + "@softmax_p";
-    const std::string key_softmax_src_mem_p = key + "@softmax_src_mem_p";
-    const std::string key_softmax_dst_mem_p = key + "@softmax_dst_mem_p";
-
-    std::shared_ptr<void> softmax_p = dev_ctx.GetBlob(key_softmax_p);
-    if (softmax_p == nullptr) {
-      // Currently only NC data format is supported
-      auto softmax_md =
-          MKLDNNMemDesc({softmax_tz}, memory::f32, memory::format::nc);
-      // Normalization is made after innermost dimension eg. C out of NC
-      auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
-                                                softmax_md, 1 /* dim: C */);
-      // create memory primitives
-      auto softmax_src_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{softmax_md, mkldnn_engine},
-          static_cast<void*>(const_cast<T*>(input_data)));
-      dev_ctx.SetBlob(key_softmax_src_mem_p, softmax_src_memory_p);
-      auto softmax_dst_memory_p = std::make_shared<memory>(
-          memory::primitive_desc{softmax_md, mkldnn_engine},
-          static_cast<void*>(output_data));
-      dev_ctx.SetBlob(key_softmax_dst_mem_p, softmax_dst_memory_p);
-
-      auto softmax_forward_pd =
-          std::make_shared<softmax_forward::primitive_desc>(softmax_desc,
-                                                            mkldnn_engine);
-      softmax_p = std::make_shared<softmax_forward>(
-          *(softmax_forward_pd.get()),
-          *(static_cast<memory*>(softmax_src_memory_p.get())),
-          *(static_cast<memory*>(softmax_dst_memory_p.get())));
-      dev_ctx.SetBlob(key_softmax_p, softmax_p);
-    } else {
-      // Primitives already exist
-      auto src_memory_p = std::static_pointer_cast<memory>(
-          dev_ctx.GetBlob(key_softmax_src_mem_p));
-      PADDLE_ENFORCE(src_memory_p != nullptr,
-                     "Fail to find softmax src mem_p in device context");
-      auto dst_memory_p = std::static_pointer_cast<memory>(
-          dev_ctx.GetBlob(key_softmax_dst_mem_p));
-      PADDLE_ENFORCE(dst_memory_p != nullptr,
-                     "Fail to find softmax dst mem_p in device context");
-      src_memory_p->set_data_handle(
-          reinterpret_cast<void*>(const_cast<T*>(input_data)));
-      dst_memory_p->set_data_handle(output_data);
-    }
+    const std::string key =
+        platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
+    const std::string key_softmax_pd = key + "@softmax_pd";
+
+    // Currently only NC data format is supported
+    auto softmax_md = MKLDNNMemDesc(
+        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
+    // Normalization is made after innermost dimension eg. C out of NC
+    auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
+                                              softmax_md, 1 /* dim: C */);
+    auto softmax_pd = std::make_shared<mkldnn::softmax_forward::primitive_desc>(
+        softmax_desc, mkldnn_engine);
+    dev_ctx.SetBlob(key_softmax_pd, softmax_pd);
+
+    SoftmaxMKLDNNHandler handler(softmax_pd, dev_ctx, mkldnn_engine, key);
+    auto softmax_src_memory_p =
+        handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
+    auto softmax_dst_memory_p =
+        handler.AcquireDstMemory(softmax_md, to_void_cast<T>(output_data));
+    auto softmax_p =
+        handler.AcquireSoftmax(softmax_dst_memory_p, softmax_src_memory_p);

     std::vector<primitive> pipeline{
         *(static_cast<softmax_forward::primitive*>(softmax_p.get()))};
@@ -120,10 +164,83 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
   }
 };

+template <typename T>
+class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
+                   "It must use CPUPlace.");
+
+    auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+    auto mkldnn_engine = dev_ctx.GetEngine();
+    const Tensor* output = ctx.Input<Tensor>("Out");
+    const T* dst_data = output->data<T>();
+
+    auto* dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
+    const auto* diff_dst_ptr = dout->template data<T>();
+
+    auto* dx =
+        ctx.template Output<framework::Tensor>(framework::GradVarName("X"));
+    T* diff_src_ptr = dx->template mutable_data<T>(ctx.GetPlace());
+
+    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    std::vector<int> src_tz(dst_tz);
+    PADDLE_ENFORCE(output->dims().size() == 2UL,
+                   "The input of softmax op must be a 2D matrix.");
+    // MKL-DNN does support softmax over a selected axis. Having a 2D Tensor,
+    // we normalize over the final axis, i.e. axis: 1
+    PADDLE_ENFORCE(((src_tz[0] == dst_tz[0]) && (src_tz[1] == dst_tz[1])),
+                   "Softmax input and output dimensions should match");
+    // Same memory descriptor to be used for input and output
+    memory::dims softmax_tz = {src_tz[0], src_tz[1]};
+    // Currently only supports NC data format
+    // retrieve softmax forward primitive desc from device context
+    const std::string key =
+        platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out"));
+    const std::string key_softmax_pd = key + "@softmax_pd";
+
+    auto softmax_pd =
+        std::static_pointer_cast<mkldnn::softmax_forward::primitive_desc>(
+            dev_ctx.GetBlob(key_softmax_pd));
+    PADDLE_ENFORCE(softmax_pd != nullptr,
+                   "Fail to find softmax_pd in device context");
+
+    // TODO(jczaja): Add layouts support when there is a need to do so
+    // Two dimensional softmax does support NC format
+    auto data_softmax_md = MKLDNNMemDesc(
+        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
+    auto diff_softmax_md = MKLDNNMemDesc(
+        {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
+    // Normalization is made after innermost dimension eg. C out of NC
+    auto softmax_bwd_desc =
+        softmax_backward::desc(diff_softmax_md, data_softmax_md, 1 /* dim: C */);
+    auto softmax_bwd_pd =
+        std::make_shared<mkldnn::softmax_backward::primitive_desc>(
+            softmax_bwd_desc, mkldnn_engine, *softmax_pd);
+
+    SoftmaxMKLDNNHandler handler(softmax_pd, softmax_bwd_pd, dev_ctx,
+                                 mkldnn_engine, key);
+    auto dst_memory_p =
+        handler.AcquireDstMemory(data_softmax_md, to_void_cast<T>(dst_data));
+    auto diff_dst_memory_p = handler.AcquireDiffDstMemory(
+        diff_softmax_md, to_void_cast<T>(diff_dst_ptr));
+    auto diff_src_memory_p = handler.AcquireDiffSrcMemory(
+        diff_softmax_md, to_void_cast<T>(diff_src_ptr));
+
+    // Get primitive from device context
+    auto softmax_bwd_p = handler.AcquireSoftmaxBackward(
+        dst_memory_p, diff_dst_memory_p, diff_src_memory_p);
+
+    std::vector<primitive> pipeline{*softmax_bwd_p};
+    stream(stream::kind::eager).submit(pipeline).wait();
+  }
+};
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;

 REGISTER_OP_KERNEL(softmax, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::SoftmaxMKLDNNKernel<float>);
+REGISTER_OP_KERNEL(softmax_grad, MKLDNN, ::paddle::platform::CPUPlace,
+                   ops::SoftmaxMKLDNNGradKernel<float>);