@@ -95,6 +95,26 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
   }
 }
 
+template <typename T>
+static void ReorderInput(framework::Tensor* tensor,
+                         const platform::Place& place,
+                         const mkldnn::engine& engine,
+                         bool isFourDim) {
+  using platform::to_void_cast;
+  auto dims = paddle::framework::vectorize2int(tensor->dims());
+  framework::Tensor out_tensor;
+  out_tensor.Resize(tensor->dims());
+  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
+  out_tensor.set_layout(tensor->layout());
+  mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
+                                   tensor->format()}, engine}, to_void_cast<T>(tensor->data<T>())};
+  mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
+                                    out_tensor.format()}, engine},
+                                  to_void_cast<T>(out_tensor.mutable_data<T>(place))};
+  platform::Reorder(input_memory, output_memory);
+  tensor->ShareDataWith(out_tensor);
+}
+
 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -111,63 +131,78 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 
     auto x_dims = x->dims();
     auto y_dims_untrimmed = y->dims();
+    auto x_int_dims = paddle::framework::vectorize2int(x_dims);
 
     UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
     UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
 
-    if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) {
-      if (x_dims != y_dims_untrimmed) {
-        int pre, n, post;
-        get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
+    const bool are_dims_divisable = !(x_int_dims[1] % 16);
+    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
+    const bool is_y_format_correct = y->format() == memory::format::nc;
+    if (is_x_format_correct && is_y_format_correct && are_dims_divisable) {
+      int pre, n, post;
+      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
 
-        if (post == 1) {
-          PADDLE_THROW("Not implemented when post is 1");
-        } else {
-          // Just check whether it works for RE-Resnext.
-          PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
+      if (post == 1) {
+        PADDLE_THROW("Not implemented when post is 1");
+      } else {
+        // Just check whether it works for RE-Resnext.
+        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
 
-          int n = x_dims[0];
-          int c = x_dims[1];
-          int h = x_dims[2];
-          int w = x_dims[3];
+        int n = x_dims[0];
+        int c = x_dims[1];
+        int h = x_dims[2];
+        int w = x_dims[3];
 
-          PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
-                         "Y should be in nc format");
+        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
+                       "Y should be in nc format");
 
-          constexpr int simd_width = 16;
-          int C = c / simd_width;
+        constexpr int simd_width = 16;
+        int C = c / simd_width;
 
-          vector_mul mul;
+        vector_mul mul;
 
-          using mul_func_t =
-              void (*)(const float*, const float*, float*, int, int);
+        using mul_func_t =
+            void (*)(const float*, const float*, float*, int, int);
 
-          mul_func_t mul_func = (mul_func_t)mul.getCode();
+        mul_func_t mul_func = (mul_func_t)mul.getCode();
 
-          #pragma omp parallel for collapse(2)
-          for (int ni = 0; ni < n; ni++) {
-            for (int ci = 0; ci < C; ci++) {
-              auto ptr_x =
-                  x_data + ni * C * h * w * simd_width +
-                  ci * h * w * simd_width;
+        #pragma omp parallel for collapse(2)
+        for (int ni = 0; ni < n; ni++) {
+          for (int ci = 0; ci < C; ci++) {
+            auto ptr_x =
+                x_data + ni * C * h * w * simd_width +
+                ci * h * w * simd_width;
 
-              auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-              auto ptr_z =
-                  z_data + ni * C * h * w * simd_width +
-                  ci * h * w * simd_width;
+            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+            auto ptr_z =
+                z_data + ni * C * h * w * simd_width +
+                ci * h * w * simd_width;
 
-              mul_func(ptr_x, ptr_y, ptr_z, h, w);
-            }
+            mul_func(ptr_x, ptr_y, ptr_z, h, w);
           }
         }
-
-        z->set_layout(DataLayout::kMKLDNN);
-        z->set_format(x->format());
-      } else {
-        PADDLE_THROW("Not implemented when dims are equal");
       }
+
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(x->format());
     } else {
       // Fallback to naive version:
+      const bool are_inputs_in_same_format = x->format() == y->format();
+      const bool is_x_nchw = x->format() == memory::format::nchw;
+      const bool is_x_nc = x->format() == memory::format::nc;
+      const bool is_y_nchw = y->format() == memory::format::nchw;
+      const bool is_y_nc = y->format() == memory::format::nc;
+      if (!are_inputs_in_same_format) {
+        using platform::MKLDNNDeviceContext;
+        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+        const auto& mkldnn_engine = dev_ctx.GetEngine();
+        if (!(is_x_nchw || is_x_nc))
+          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4);
+        if (!(is_y_nchw || is_y_nc))
+          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4);
+      }
+
       auto mul_func = [](T a, T b) -> T { return a * b; };
 
       TransformFunctor<decltype(mul_func), T,
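Note on the fast path above: the pointer arithmetic assumes x and z are stored in MKL-DNN's nChw16c blocked layout, i.e. [n][c/16][h][w][16], while y stays in plain nc, which is also why the branch is guarded by the divisibility-by-16 check on x_int_dims[1]. As a reading aid only, here is a minimal plain-C++ sketch of what the JIT-generated vector_mul kernel is assumed to compute; reference_blocked_mul is a hypothetical name, not part of this patch, and the "broadcast 16 channel scales of y over every spatial position of the x block" semantics is inferred from the offsets used in the loop above.

```cpp
constexpr int simd_width = 16;  // channel block width of the nChw16c layout

// Reference model (assumption, not the patch's JIT kernel):
// x, z are nChw16c blocks laid out as [n][C][h][w][16] with C = c / 16,
// y is nc laid out as [n][C][16]. mul_func(ptr_x, ptr_y, ptr_z, h, w) is
// assumed to scale each spatial position of a 16-channel block of x by the
// matching 16 values of y.
void reference_blocked_mul(const float* x, const float* y, float* z,
                           int n, int C, int h, int w) {
  for (int ni = 0; ni < n; ++ni) {
    for (int ci = 0; ci < C; ++ci) {
      // Same offsets as ptr_x / ptr_y / ptr_z in the kernel above.
      const float* ptr_x = x + (ni * C + ci) * h * w * simd_width;
      const float* ptr_y = y + (ni * C + ci) * simd_width;
      float* ptr_z = z + (ni * C + ci) * h * w * simd_width;
      for (int hw = 0; hw < h * w; ++hw) {
        for (int k = 0; k < simd_width; ++k) {
          ptr_z[hw * simd_width + k] = ptr_x[hw * simd_width + k] * ptr_y[k];
        }
      }
    }
  }
}
```

In this blocked layout a logical element (n, c, h, w) lives at index ((n * (c_total / 16) + c / 16) * H * W + h * W + w) * 16 + c % 16, so each contiguous run of 16 floats holds one spatial position of one channel block, which is what makes the per-position 16-wide multiply natural to vectorize.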