@@ -26,9 +26,9 @@ using paddle::platform::MKLDNNMemDesc;
26
26
27
27
using mkldnn::memory; // Note: paddle has also "memory" namespace
28
28
using mkldnn::primitive;
29
- using mkldnn::softmax_forward;
30
- using mkldnn::softmax_backward;
31
29
using mkldnn::prop_kind;
30
+ using mkldnn::softmax_backward;
31
+ using mkldnn::softmax_forward;
32
32
using mkldnn::stream;
33
33
using platform::to_void_cast;
34
34
@@ -113,17 +113,27 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
113
113
auto mkldnn_engine = dev_ctx.GetEngine ();
114
114
const Tensor* input = ctx.Input <Tensor>(" X" );
115
115
Tensor* output = ctx.Output <Tensor>(" Out" );
116
- PADDLE_ENFORCE (input->dims ().size () == 2UL ,
117
- " The input of softmax op must be a 2D matrix." );
118
- const T* input_data = input->data <T>();
119
- // allocate memory for output
120
- T* output_data = output->mutable_data <T>(ctx.GetPlace ());
121
- std::vector<int > src_tz = paddle::framework::vectorize2int (input->dims ());
122
- std::vector<int > dst_tz = paddle::framework::vectorize2int (output->dims ());
123
- // MKL-DNN does support softmax over selected axis. Having 2D Tensor,
124
- // we will make normalization after final eg. axis: 1
125
- PADDLE_ENFORCE (((src_tz[0 ] == dst_tz[0 ]) && (src_tz[1 ] == dst_tz[1 ])),
126
- " Softmax input and output dimensions should match" );
116
+ PADDLE_ENFORCE_EQ (
117
+ input->dims (), output->dims (),
118
+ " The shape of softmax's input and output must be identical." );
119
+
120
+ // make sure 'output' holds memory, which will be shared by
121
+ // 'flattened_output' later.
122
+ output->mutable_data <T>(ctx.GetPlace ());
123
+
124
+ // flatten input and output to 2-D matrices
125
+ auto dims = input->dims (); // input and output share the same shape
126
+ auto flattened_dims = framework::flatten_to_2d (dims, dims.size () - 1 );
127
+ framework::Tensor flattened_input;
128
+ framework::Tensor flattened_output;
129
+ flattened_input.ShareDataWith (*input).Resize (flattened_dims);
130
+ flattened_output.ShareDataWith (*output).Resize (flattened_dims);
131
+
132
+ const T* input_data = flattened_input.data <T>();
133
+ T* output_data = flattened_output.mutable_data <T>(ctx.GetPlace ());
134
+
135
+ std::vector<int > src_tz = paddle::framework::vectorize2int (flattened_dims);
136
+ std::vector<int > dst_tz = src_tz;
127
137
// Same memory descriptor to be used for input and output
128
138
memory::dims softmax_tz = {src_tz[0 ], src_tz[1 ]};
129
139
// Generate keys for storing/retrieving primitives for this operator
@@ -174,23 +184,34 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
174
184
auto & dev_ctx = ctx.template device_context <MKLDNNDeviceContext>();
175
185
auto mkldnn_engine = dev_ctx.GetEngine ();
176
186
const Tensor* output = ctx.Input <Tensor>(" Out" );
177
- const T* dst_data = output->data <T>();
178
-
179
187
auto * dout = ctx.template Input <Tensor>(framework::GradVarName (" Out" ));
180
- const auto * diff_dst_ptr = dout->template data <T>();
181
-
182
188
auto * dx =
183
189
ctx.template Output <framework::Tensor>(framework::GradVarName (" X" ));
184
- T* diff_src_ptr = dx->template mutable_data <T>(ctx.GetPlace ());
185
190
186
- std::vector<int > dst_tz = paddle::framework::vectorize2int (output->dims ());
191
+ PADDLE_ENFORCE_EQ (
192
+ dout->dims (), dx->dims (),
193
+ " The shape of softmax_grad's input and output must be identical." );
194
+
195
+ // make sure 'dx' holds memory, which will be shared by 'flattened_dx'
196
+ // later.
197
+ dx->template mutable_data <T>(ctx.GetPlace ());
198
+
199
+ auto dims = dout->dims (); // input and output share the same shape
200
+ auto flattened_dims = framework::flatten_to_2d (dims, dims.size () - 1 );
201
+ framework::Tensor flattened_output;
202
+ framework::Tensor flattened_dout;
203
+ framework::Tensor flattened_dx;
204
+ flattened_output.ShareDataWith (*output).Resize (flattened_dims);
205
+ flattened_dout.ShareDataWith (*dout).Resize (flattened_dims);
206
+ flattened_dx.ShareDataWith (*dx).Resize (flattened_dims);
207
+
208
+ const T* dst_data = flattened_output.data <T>();
209
+ const T* diff_dst_ptr = flattened_dout.template data <T>();
210
+ T* diff_src_ptr = flattened_dx.template mutable_data <T>(ctx.GetPlace ());
211
+
212
+ std::vector<int > dst_tz = paddle::framework::vectorize2int (flattened_dims);
187
213
std::vector<int > src_tz (dst_tz);
188
- PADDLE_ENFORCE (output->dims ().size () == 2UL ,
189
- " The input of softmax op must be a 2D matrix." );
190
- // MKL-DNN does support softmax over selected axis. Having 2D Tensor,
191
- // we will make normalization after final eg. axis: 1
192
- PADDLE_ENFORCE (((src_tz[0 ] == dst_tz[0 ]) && (src_tz[1 ] == dst_tz[1 ])),
193
- " Softmax input and output dimensions should match" );
214
+
194
215
// Same memory descriptor to be used for input and output
195
216
memory::dims softmax_tz = {src_tz[0 ], src_tz[1 ]};
196
217
// Currently only supports NC data format
0 commit comments