Skip to content

Commit 08f63c4

Browse files
committed
MKLDNN elementwise_mul: Lint changes to UT & integration
test=develop
1 parent 73b7cd0 commit 08f63c4

File tree

3 files changed: +50 additions, −40 deletions

paddle/fluid/operators/elementwise/elementwise_op.h

Lines changed: 12 additions & 12 deletions
Original file line number · Diff line number · Diff line change
@@ -98,19 +98,19 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
9898
AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
9999
.SetDefault(false);
100100
AddAttr<std::string>(
101-
"x_data_format",
102-
"(string, default NCHW) Only used in mkldnn"
103-
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
104-
"Defaults to \"\". Specify the data format of the output data, "
105-
"the input will be transformed automatically. ")
106-
.SetDefault("");
101+
"x_data_format",
102+
"(string, default NCHW) Only used in mkldnn"
103+
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
104+
"Defaults to \"\". Specify the data format of the output data, "
105+
"the input will be transformed automatically. ")
106+
.SetDefault("");
107107
AddAttr<std::string>(
108-
"y_data_format",
109-
"(string, default \"\") Only used in mkldnn"
110-
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
111-
"Defaults to \"\". Specify the data format of the output data, "
112-
"the input will be transformed automatically. ")
113-
.SetDefault("");
108+
"y_data_format",
109+
"(string, default \"\") Only used in mkldnn"
110+
"An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
111+
"Defaults to \"\". Specify the data format of the output data, "
112+
"the input will be transformed automatically. ")
113+
.SetDefault("");
114114
AddComment(string::Sprintf(R"DOC(
115115
Elementwise %s Operator
116116

paddle/fluid/operators/elementwise_mul_mkldnn_op.cc

Lines changed: 27 additions & 27 deletions
Original file line number · Diff line number · Diff line change
@@ -71,22 +71,22 @@ void check(const float* x, const float* y, float* z, int w) {
7171
static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
7272
std::transform(format.begin(), format.end(), format.begin(), ::tolower);
7373

74-
if(!format.compare("nchw")) {
74+
if (!format.compare("nchw")) {
7575
return memory::format::nchw;
76-
} else if(!format.compare("nchw16c")) {
76+
} else if (!format.compare("nchw16c")) {
7777
return memory::format::nChw16c;
78-
} else if(!format.compare("nchw8c")) {
78+
} else if (!format.compare("nchw8c")) {
7979
return memory::format::nChw8c;
80-
} else if(!format.compare("nhwc")) {
80+
} else if (!format.compare("nhwc")) {
8181
return memory::format::nhwc;
8282
} else {
8383
return memory::format::any;
8484
}
8585
}
8686

8787
static void UpdateDataFormat(const framework::ExecutionContext& ctx,
88-
framework::Tensor* tensor, const char* attribute) {
89-
if(ctx.op().HasAttr(attribute)) {
88+
framework::Tensor* tensor, const char* attribute) {
89+
if (ctx.op().HasAttr(attribute)) {
9090
auto format_as_string = ctx.Attr<std::string>(attribute);
9191
auto format = StringToMKLDNNFormat(format_as_string);
9292
if (format != memory::format::any) {
@@ -98,19 +98,19 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
9898
template <typename T>
9999
static void ReorderInput(framework::Tensor* tensor,
100100
const platform::Place& place,
101-
const mkldnn::engine& engine,
102-
bool isFourDim) {
101+
const mkldnn::engine& engine, bool isFourDim) {
103102
using platform::to_void_cast;
104103
auto dims = paddle::framework::vectorize2int(tensor->dims());
105104
framework::Tensor out_tensor;
106105
out_tensor.Resize(tensor->dims());
107106
out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
108107
out_tensor.set_layout(tensor->layout());
109-
mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
110-
tensor->format()}, engine}, to_void_cast<T>(tensor->data<T>())};
111-
mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
112-
out_tensor.format()}, engine},
113-
to_void_cast<T>(out_tensor.mutable_data<T>(place))};
108+
mkldnn::memory input_memory = {
109+
{{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
110+
to_void_cast<T>(tensor->data<T>())};
111+
mkldnn::memory output_memory = {
112+
{{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
113+
to_void_cast<T>(out_tensor.mutable_data<T>(place))};
114114
platform::Reorder(input_memory, output_memory);
115115
tensor->ShareDataWith(out_tensor);
116116
}
@@ -163,21 +163,19 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
163163
vector_mul mul;
164164

165165
using mul_func_t =
166-
void (*)(const float *, const float *, float *, int, int);
166+
void (*)(const float*, const float*, float*, int, int);
167167

168-
mul_func_t mul_func = (mul_func_t) mul.getCode();
168+
mul_func_t mul_func = (mul_func_t)mul.getCode();
169169

170-
#pragma omp parallel for collapse(2)
170+
#pragma omp parallel for collapse(2)
171171
for (int ni = 0; ni < n; ni++) {
172172
for (int ci = 0; ci < C; ci++) {
173173
auto ptr_x =
174-
x_data + ni * C * h * w * simd_width +
175-
ci * h * w * simd_width;
174+
x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
176175

177176
auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
178177
auto ptr_z =
179-
z_data + ni * C * h * w * simd_width +
180-
ci * h * w * simd_width;
178+
z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
181179

182180
mul_func(ptr_x, ptr_y, ptr_z, h, w);
183181
}
@@ -189,18 +187,20 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
189187
} else {
190188
// Fallback to naive version:
191189
const bool are_inputs_in_same_format = x->format() == y->format();
192-
const bool is_x_nchw= x->format() == memory::format::nchw;
190+
const bool is_x_nchw = x->format() == memory::format::nchw;
193191
const bool is_x_nc = x->format() == memory::format::nc;
194-
const bool is_y_nchw= y->format() == memory::format::nchw;
192+
const bool is_y_nchw = y->format() == memory::format::nchw;
195193
const bool is_y_nc = y->format() == memory::format::nc;
196-
if(!are_inputs_in_same_format) {
194+
if (!are_inputs_in_same_format) {
197195
using platform::MKLDNNDeviceContext;
198196
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
199197
const auto& mkldnn_engine = dev_ctx.GetEngine();
200-
if(!(is_x_nchw || is_x_nc))
201-
ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4);
202-
if(!(is_y_nchw || is_y_nc))
203-
ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4);
198+
if (!(is_x_nchw || is_x_nc))
199+
ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
200+
x->dims().size() == 4);
201+
if (!(is_y_nchw || is_y_nc))
202+
ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
203+
y->dims().size() == 4);
204204
}
205205

206206
auto mul_func = [](T a, T b) -> T { return a * b; };

python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py

Lines changed: 11 additions & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,7 @@
2020
from paddle.fluid.op import Operator
2121
from test_elementwise_mul_op import *
2222

23+
2324
class TestElementwiseMulMKLDNNOp_BroadcastNCHW16c(ElementwiseMulOp):
2425
def init_input_output(self):
2526
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -49,7 +50,9 @@ def test_check_grad_ingore_x(self):
4950
def test_check_grad_ingore_y(self):
5051
pass
5152

52-
@unittest.skip("Not implemented yet.") # TODO(mgallus): enable when implemented.
53+
54+
@unittest.skip(
55+
"Not implemented yet.") # TODO(mgallus): enable when implemented.
5356
class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp):
5457
def init_input_output(self):
5558
x = np.random.rand(1, 8, 2, 2).astype(self.dtype)
@@ -79,6 +82,7 @@ def test_check_grad_ingore_x(self):
7982
def test_check_grad_ingore_y(self):
8083
pass
8184

85+
8286
class TestElementwiseMulMKLDNNOp_FallbackNCHW(ElementwiseMulOp):
8387
def init_input_output(self):
8488
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -101,6 +105,7 @@ def test_check_grad_ingore_x(self):
101105
def test_check_grad_ingore_y(self):
102106
pass
103107

108+
104109
class TestElementwiseMulMKLDNNOp_FallbackNCHW16C(ElementwiseMulOp):
105110
def init_input_output(self):
106111
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -130,6 +135,7 @@ def test_check_grad_ingore_x(self):
130135
def test_check_grad_ingore_y(self):
131136
pass
132137

138+
133139
class TestElementwiseMulMKLDNNOp_FallbackNoReorders(ElementwiseMulOp):
134140
def init_input_output(self):
135141
x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -159,6 +165,7 @@ def test_check_grad_ingore_x(self):
159165
def test_check_grad_ingore_y(self):
160166
pass
161167

168+
162169
class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
163170
def init_input_output(self):
164171
self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -187,6 +194,7 @@ def test_check_grad_ingore_x(self):
187194
def test_check_grad_ingore_y(self):
188195
pass
189196

197+
190198
class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
191199
def init_input_output(self):
192200
self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -215,6 +223,7 @@ def test_check_grad_ingore_x(self):
215223
def test_check_grad_ingore_y(self):
216224
pass
217225

226+
218227
class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
219228
def init_input_output(self):
220229
self.x = np.random.rand(1, 16).astype(self.dtype)
@@ -242,5 +251,6 @@ def test_check_grad_ingore_x(self):
242251
def test_check_grad_ingore_y(self):
243252
pass
244253

254+
245255
if __name__ == '__main__':
246256
unittest.main()

Comments: 0 commit comments