@@ -71,22 +71,22 @@ void check(const float* x, const float* y, float* z, int w) {
71
71
// Maps a layout name ("nchw", "nchw16c", "nchw8c", "nhwc", case-insensitive)
// to the corresponding MKLDNN memory format. The input string is lower-cased
// in place as a side effect. Unrecognized names yield memory::format::any.
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
  std::transform(format.begin(), format.end(), format.begin(), ::tolower);

  // Guard-style lookup; operator== on std::string is equivalent to
  // !compare() for exact-match tests.
  if (format == "nchw") {
    return memory::format::nchw;
  }
  if (format == "nchw16c") {
    return memory::format::nChw16c;
  }
  if (format == "nchw8c") {
    return memory::format::nChw8c;
  }
  if (format == "nhwc") {
    return memory::format::nhwc;
  }
  return memory::format::any;
}
86
86
87
87
static void UpdateDataFormat (const framework::ExecutionContext& ctx,
88
- framework::Tensor* tensor, const char * attribute) {
89
- if (ctx.op ().HasAttr (attribute)) {
88
+ framework::Tensor* tensor, const char * attribute) {
89
+ if (ctx.op ().HasAttr (attribute)) {
90
90
auto format_as_string = ctx.Attr <std::string>(attribute);
91
91
auto format = StringToMKLDNNFormat (format_as_string);
92
92
if (format != memory::format::any) {
@@ -98,19 +98,19 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
98
98
// Converts `tensor`'s data, in place, from whatever MKLDNN memory format it
// currently carries into the plain layout: nchw when isFourDim is true,
// nc otherwise. A temporary tensor with the target format is filled via an
// MKLDNN reorder, then shared back into `tensor`.
// NOTE(review): mutates the input tensor's buffer/format; callers pass a
// const tensor through a cast — confirm that is intentional at call sites.
template <typename T>
static void ReorderInput(framework::Tensor* tensor,
                         const platform::Place& place,
                         const mkldnn::engine& engine, bool isFourDim) {
  using platform::to_void_cast;
  auto dims = paddle::framework::vectorize2int(tensor->dims());
  // Destination tensor: same shape/layout, but the plain target format.
  framework::Tensor out_tensor;
  out_tensor.Resize(tensor->dims());
  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
  out_tensor.set_layout(tensor->layout());
  // Source memory primitive described with the tensor's current format.
  mkldnn::memory input_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
      to_void_cast<T>(tensor->data<T>())};
  // Destination memory primitive described with the plain target format;
  // mutable_data allocates the destination buffer on `place`.
  mkldnn::memory output_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
  platform::Reorder(input_memory, output_memory);
  // Hand the reordered buffer (and its format metadata) back to the caller.
  tensor->ShareDataWith(out_tensor);
}
@@ -163,21 +163,19 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
163
163
vector_mul mul;
164
164
165
165
using mul_func_t =
166
- void (*)(const float *, const float *, float *, int , int );
166
+ void (*)(const float *, const float *, float *, int , int );
167
167
168
- mul_func_t mul_func = (mul_func_t ) mul.getCode ();
168
+ mul_func_t mul_func = (mul_func_t )mul.getCode ();
169
169
170
- #pragma omp parallel for collapse(2)
170
+ #pragma omp parallel for collapse(2)
171
171
for (int ni = 0 ; ni < n; ni++) {
172
172
for (int ci = 0 ; ci < C; ci++) {
173
173
auto ptr_x =
174
- x_data + ni * C * h * w * simd_width +
175
- ci * h * w * simd_width;
174
+ x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
176
175
177
176
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
178
177
auto ptr_z =
179
- z_data + ni * C * h * w * simd_width +
180
- ci * h * w * simd_width;
178
+ z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
181
179
182
180
mul_func (ptr_x, ptr_y, ptr_z, h, w);
183
181
}
@@ -189,18 +187,20 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
189
187
} else {
190
188
// Fallback to naive version:
191
189
const bool are_inputs_in_same_format = x->format () == y->format ();
192
- const bool is_x_nchw= x->format () == memory::format::nchw;
190
+ const bool is_x_nchw = x->format () == memory::format::nchw;
193
191
const bool is_x_nc = x->format () == memory::format::nc;
194
- const bool is_y_nchw= y->format () == memory::format::nchw;
192
+ const bool is_y_nchw = y->format () == memory::format::nchw;
195
193
const bool is_y_nc = y->format () == memory::format::nc;
196
- if (!are_inputs_in_same_format) {
194
+ if (!are_inputs_in_same_format) {
197
195
using platform::MKLDNNDeviceContext;
198
196
auto & dev_ctx = ctx.template device_context <MKLDNNDeviceContext>();
199
197
const auto & mkldnn_engine = dev_ctx.GetEngine ();
200
- if (!(is_x_nchw || is_x_nc))
201
- ReorderInput<T>((Tensor*)x, ctx.GetPlace (), mkldnn_engine, x->dims ().size () == 4 );
202
- if (!(is_y_nchw || is_y_nc))
203
- ReorderInput<T>((Tensor*)y, ctx.GetPlace (), mkldnn_engine, y->dims ().size () == 4 );
198
+ if (!(is_x_nchw || is_x_nc))
199
+ ReorderInput<T>((Tensor*)x, ctx.GetPlace (), mkldnn_engine,
200
+ x->dims ().size () == 4 );
201
+ if (!(is_y_nchw || is_y_nc))
202
+ ReorderInput<T>((Tensor*)y, ctx.GetPlace (), mkldnn_engine,
203
+ y->dims ().size () == 4 );
204
204
}
205
205
206
206
auto mul_func = [](T a, T b) -> T { return a * b; };
0 commit comments