
Commit 49b0932

MKLDNN elementwise_mul: Reorder on non-nchw input, fallback on non-16 divisible fm

test=develop

1 parent f820573
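
Note on the layouts involved: nChw16c is MKL-DNN's blocked layout, which stores channels in contiguous blocks of 16 so the JIT-ed vector_mul kernel can process one block per SIMD operation; this is why the fast path requires the channel count to be divisible by 16. A minimal numpy sketch of the mapping (an illustration for this writeup, not part of the commit):

import numpy as np

n, c, h, w = 1, 32, 2, 2
simd_width = 16
assert c % simd_width == 0  # otherwise this commit falls back to the naive path

nchw = np.random.rand(n, c, h, w).astype(np.float32)
# nChw16c: [n, c/16, h, w, 16] -- each pixel carries 16 contiguous channels
nchw16c = nchw.reshape(n, c // simd_width, simd_width, h, w).transpose(0, 1, 3, 4, 2)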

File tree

2 files changed, +131 −42 lines

paddle/fluid/operators/elementwise_mul_mkldnn_op.cc

Lines changed: 73 additions & 38 deletions
@@ -95,6 +95,26 @@ static void UpdateDataFormat(const framework::ExecutionContext& ctx,
   }
 }
 
+template <typename T>
+static void ReorderInput(framework::Tensor* tensor,
+                         const platform::Place& place,
+                         const mkldnn::engine& engine,
+                         bool isFourDim) {
+  using platform::to_void_cast;
+  auto dims = paddle::framework::vectorize2int(tensor->dims());
+  framework::Tensor out_tensor;
+  out_tensor.Resize(tensor->dims());
+  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
+  out_tensor.set_layout(tensor->layout());
+  mkldnn::memory input_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
+      tensor->format()}, engine}, to_void_cast<T>(tensor->data<T>())};
+  mkldnn::memory output_memory = {{{dims, platform::MKLDNNGetDataType<T>(),
+      out_tensor.format()}, engine},
+      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
+  platform::Reorder(input_memory, output_memory);
+  tensor->ShareDataWith(out_tensor);
+}
+
 template <typename T>
 class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
  public:
@@ -111,63 +131,78 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 
     auto x_dims = x->dims();
     auto y_dims_untrimmed = y->dims();
+    auto x_int_dims = paddle::framework::vectorize2int(x_dims);
 
     UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
     UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");
 
-    if (x->format() == memory::format::nChw16c && y->format() == memory::format::nc) {
-      if (x_dims != y_dims_untrimmed) {
-        int pre, n, post;
-        get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
+    const bool are_dims_divisible = !(x_int_dims[1] % 16);
+    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
+    const bool is_y_format_correct = y->format() == memory::format::nc;
+    if (is_x_format_correct && is_y_format_correct && are_dims_divisible) {
+      int pre, n, post;
+      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);
 
-        if (post == 1) {
-          PADDLE_THROW("Not implemented when post is 1");
-        } else {
-          // Just check whether it works for RE-Resnext.
-          PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
+      if (post == 1) {
+        PADDLE_THROW("Not implemented when post is 1");
+      } else {
+        // Just check whether it works for RE-Resnext.
+        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");
 
-          int n = x_dims[0];
-          int c = x_dims[1];
-          int h = x_dims[2];
-          int w = x_dims[3];
+        int n = x_dims[0];
+        int c = x_dims[1];
+        int h = x_dims[2];
+        int w = x_dims[3];
 
-          PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
-                         "Y should be in nc format");
+        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
+                       "Y should be in nc format");
 
-          constexpr int simd_width = 16;
-          int C = c / simd_width;
+        constexpr int simd_width = 16;
+        int C = c / simd_width;
 
-          vector_mul mul;
+        vector_mul mul;
 
-          using mul_func_t =
-              void (*)(const float *, const float *, float *, int, int);
+        using mul_func_t =
+            void (*)(const float *, const float *, float *, int, int);
 
-          mul_func_t mul_func = (mul_func_t) mul.getCode();
+        mul_func_t mul_func = (mul_func_t) mul.getCode();
 
-          #pragma omp parallel for collapse(2)
-          for (int ni = 0; ni < n; ni++) {
-            for (int ci = 0; ci < C; ci++) {
-              auto ptr_x =
-                  x_data + ni * C * h * w * simd_width +
-                  ci * h * w * simd_width;
+        #pragma omp parallel for collapse(2)
+        for (int ni = 0; ni < n; ni++) {
+          for (int ci = 0; ci < C; ci++) {
+            auto ptr_x =
+                x_data + ni * C * h * w * simd_width +
+                ci * h * w * simd_width;
 
-              auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
-              auto ptr_z =
-                  z_data + ni * C * h * w * simd_width +
-                  ci * h * w * simd_width;
+            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
+            auto ptr_z =
+                z_data + ni * C * h * w * simd_width +
+                ci * h * w * simd_width;
 
-              mul_func(ptr_x, ptr_y, ptr_z, h, w);
-            }
+            mul_func(ptr_x, ptr_y, ptr_z, h, w);
           }
         }
-
-        z->set_layout(DataLayout::kMKLDNN);
-        z->set_format(x->format());
-      } else {
-        PADDLE_THROW("Not implemented when dims are equal");
       }
+
+      z->set_layout(DataLayout::kMKLDNN);
+      z->set_format(x->format());
     } else {
       // Fallback to naive version:
+      const bool are_inputs_in_same_format = x->format() == y->format();
+      const bool is_x_nchw = x->format() == memory::format::nchw;
+      const bool is_x_nc = x->format() == memory::format::nc;
+      const bool is_y_nchw = y->format() == memory::format::nchw;
+      const bool is_y_nc = y->format() == memory::format::nc;
+      if (!are_inputs_in_same_format) {
+        using platform::MKLDNNDeviceContext;
+        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
+        const auto& mkldnn_engine = dev_ctx.GetEngine();
+        if (!(is_x_nchw || is_x_nc))
+          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine, x->dims().size() == 4);
+        if (!(is_y_nchw || is_y_nc))
+          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine, y->dims().size() == 4);
+      }
+
       auto mul_func = [](T a, T b) -> T { return a * b; };
 
       TransformFunctor<decltype(mul_func), T,
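
The ReorderInput helper added above delegates to platform::Reorder, which runs an MKL-DNN reorder primitive to rewrite a tensor from whatever blocked format it arrived in into plain nchw (or nc for 2-D tensors) before the naive fallback touches the data. A hedged numpy model of the effect for the nChw16c case (the real conversion happens inside MKL-DNN; this only mirrors the element order):

import numpy as np

def to_plain_nchw(blocked, n, c, h, w, simd_width=16):
    # Invert the channel blocking: [n, c/16, h, w, 16] -> [n, c/16, 16, h, w]
    # -> [n, c, h, w]. ReorderInput handles arbitrary source formats; this
    # models only the one exercised by the fast path.
    return blocked.transpose(0, 1, 4, 2, 3).reshape(n, c, h, w)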

python/paddle/fluid/tests/unittests/test_elementwise_mul_mkldnn_op.py

Lines changed: 58 additions & 4 deletions
@@ -49,7 +49,7 @@ def test_check_grad_ingore_x(self):
     def test_check_grad_ingore_y(self):
         pass
 
-@unittest.skip("Not implemented yet.")
+@unittest.skip("Not implemented yet.")  # TODO(mgallus): enable when implemented.
 class TestElementwiseMulMKLDNNOp_BroadcastNCHW8c(ElementwiseMulOp):
     def init_input_output(self):
         x = np.random.rand(1, 8, 2, 2).astype(self.dtype)
@@ -159,8 +159,7 @@ def test_check_grad_ingore_x(self):
     def test_check_grad_ingore_y(self):
         pass
 
-@unittest.skip("Not implemented yet.")
-class TestElementwiseMulMKLDNNOp_FallbackWithReorder(ElementwiseMulOp):
+class TestElementwiseMulMKLDNNOp_FallbackWithReorder1(ElementwiseMulOp):
     def init_input_output(self):
         self.x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
         y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
@@ -169,7 +168,7 @@ def init_input_output(self):
         self.out = self.x * y
 
     def setUp(self):
-        super(TestElementwiseMulMKLDNNOp_FallbackNCHW16C, self).setUp()
+        super(TestElementwiseMulMKLDNNOp_FallbackWithReorder1, self).setUp()
         self.attrs["x_data_format"] = "nchw"
         self.attrs["y_data_format"] = "nchw16c"
 
@@ -188,5 +187,60 @@ def test_check_grad_ingore_x(self):
     def test_check_grad_ingore_y(self):
         pass
 
+class TestElementwiseMulMKLDNNOp_FallbackWithReorder2(ElementwiseMulOp):
+    def init_input_output(self):
+        self.y = np.random.rand(1, 16, 2, 2).astype(self.dtype)
+        x = np.random.rand(1, 16, 2, 2).astype(self.dtype)
+        self.x = x.transpose(0, 2, 3, 1).reshape(1, 16, 2, 2)
+
+        self.out = x * self.y
+
+    def setUp(self):
+        super(TestElementwiseMulMKLDNNOp_FallbackWithReorder2, self).setUp()
+        self.attrs["x_data_format"] = "nchw16c"
+        self.attrs["y_data_format"] = "nchw"
+
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+    def init_axis(self):
+        self.axis = 0
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
+class TestElementwiseMulMKLDNNOp_FallbackNoReorders2(ElementwiseMulOp):
+    def init_input_output(self):
+        self.x = np.random.rand(1, 16).astype(self.dtype)
+        self.y = np.random.rand(1, 16).astype(self.dtype)
+
+        self.out = self.x * self.y
+
+    def setUp(self):
+        super(TestElementwiseMulMKLDNNOp_FallbackNoReorders2, self).setUp()
+        self.attrs["x_data_format"] = "nc"
+        self.attrs["y_data_format"] = "nc"
+
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+
+    def init_axis(self):
+        self.axis = 0
+
+    def test_check_grad_normal(self):
+        pass
+
+    def test_check_grad_ingore_x(self):
+        pass
+
+    def test_check_grad_ingore_y(self):
+        pass
+
 if __name__ == '__main__':
     unittest.main()
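
A note on TestElementwiseMulMKLDNNOp_FallbackWithReorder2: my reading is that transpose(0, 2, 3, 1) followed by the reshape produces data in nhwc element order, and for c == 16 there is exactly one channel block, so the nChw16c layout [n, c/16, h, w, 16] degenerates to exactly that order. A small numpy check (illustration only, not part of the commit):

import numpy as np

n, c, h, w = 1, 16, 2, 2
x = np.random.rand(n, c, h, w).astype(np.float32)
blocked = x.reshape(n, c // 16, 16, h, w).transpose(0, 1, 3, 4, 2)  # true nChw16c
nhwc = x.transpose(0, 2, 3, 1)
print(np.array_equal(blocked.ravel(), nhwc.ravel()))  # True when c == 16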
