Commit 10fb4ce

Merge pull request #14351 from tpatejko/tpatejko/mkldnn-elementwise_mul
[MKLDNN][JIT][AVX512] Elementwise Mul
2 parents 0528067 + def272c

File tree

10 files changed: +600 −14 lines


AUTHORS.md

Lines changed: 1 addition & 0 deletions

@@ -42,6 +42,7 @@
 | QiJune      | Jun Qi         |
 | qingqing01  | Qing-Qing Dang |
 | reyoung     | Yang Yu        |
+| Sand3r-     | Michal Gallus  |
 | Superjom    | Chun-Wei Yan   |
 | tensor-tang | Jian Tang      |
 | tianbingsz  | Tian-Bing Xu   |

paddle/fluid/framework/operator.h

Lines changed: 1 addition & 0 deletions

@@ -100,6 +100,7 @@ class OperatorBase {

   const std::string& Type() const { return type_; }

+  bool HasAttr(const std::string& name) const { return attrs_.count(name); }
   template <typename T>
   inline const T& Attr(const std::string& name) const {
     PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
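The new HasAttr accessor lets kernel code probe for an optional attribute without tripping the PADDLE_ENFORCE inside Attr<T>. A minimal sketch of the intended pattern (the fallback value here is illustrative, not part of this commit):

// Guard the read: Attr<T> enforces presence, HasAttr merely tests it.
std::string format = "nchw";  // hypothetical fallback
if (ctx.op().HasAttr("x_data_format")) {
  format = ctx.Attr<std::string>("x_data_format");
}

This is exactly how UpdateDataFormat in the new MKLDNN kernel below consumes the x_data_format and y_data_format attributes.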
paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc

Lines changed: 201 additions & 0 deletions
@@ -0,0 +1,201 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <mkldnn/include/mkldnn.hpp>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

#include "paddle/fluid/platform/mkldnn_helper.h"

#include "paddle/fluid/operators/math/jit_kernel.h"
#include "xbyak.h"
#include "xbyak_util.h"

namespace paddle {
namespace operators {

using framework::DataLayout;
using mkldnn::memory;

static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
  std::transform(format.begin(), format.end(), format.begin(), ::tolower);

  if (!format.compare("nchw")) {
    return memory::format::nchw;
  } else if (!format.compare("nchw16c")) {
    return memory::format::nChw16c;
  } else if (!format.compare("nchw8c")) {
    return memory::format::nChw8c;
  } else if (!format.compare("nhwc")) {
    return memory::format::nhwc;
  } else {
    return memory::format::any;
  }
}

static void UpdateDataFormat(const framework::ExecutionContext& ctx,
                             framework::Tensor* tensor, const char* attribute) {
  if (ctx.op().HasAttr(attribute)) {
    auto format_as_string = ctx.Attr<std::string>(attribute);
    auto format = StringToMKLDNNFormat(format_as_string);
    if (format != memory::format::any) {
      tensor->set_format(format);
    }
  }
}

template <typename T>
static void ReorderInput(framework::Tensor* tensor,
                         const platform::Place& place,
                         const mkldnn::engine& engine, bool isFourDim) {
  using platform::to_void_cast;
  auto dims = paddle::framework::vectorize2int(tensor->dims());
  framework::Tensor out_tensor;
  out_tensor.Resize(tensor->dims());
  out_tensor.set_format(isFourDim ? memory::format::nchw : memory::format::nc);
  out_tensor.set_layout(tensor->layout());
  mkldnn::memory input_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), tensor->format()}, engine},
      to_void_cast<T>(tensor->data<T>())};
  mkldnn::memory output_memory = {
      {{dims, platform::MKLDNNGetDataType<T>(), out_tensor.format()}, engine},
      to_void_cast<T>(out_tensor.mutable_data<T>(place))};
  platform::Reorder(input_memory, output_memory);
  tensor->ShareDataWith(out_tensor);
}

template <typename T>
class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using Tensor = framework::Tensor;

    int axis = ctx.Attr<int>("axis");
    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* z = ctx.Output<Tensor>("Out");
    const T* x_data = x->data<T>();
    const T* y_data = y->data<T>();
    T* z_data = z->mutable_data<T>(ctx.GetPlace());

    auto x_dims = x->dims();
    auto y_dims_untrimmed = y->dims();
    auto x_int_dims = paddle::framework::vectorize2int(x_dims);

    UpdateDataFormat(ctx, (Tensor*)x, "x_data_format");
    UpdateDataFormat(ctx, (Tensor*)y, "y_data_format");

    Xbyak::util::Cpu cpu;
    const bool is_avx512_enabled = cpu.has(Xbyak::util::Cpu::tAVX512F);
    const bool are_dims_divisable = !(x_int_dims[1] % 16);
    const bool is_x_format_correct = x->format() == memory::format::nChw16c;
    const bool is_y_format_correct = y->format() == memory::format::nc;
    if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
        is_avx512_enabled) {
      int pre, n, post;
      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post);

      if (post == 1) {
        PADDLE_THROW("Not implemented when post is 1");
      } else {
        // Just check whether it works for RE-Resnext.
        PADDLE_ENFORCE_EQ(x_dims.size(), 4, "X should have 4 dimensions");

        int n = x_dims[0];
        int c = x_dims[1];
        int h = x_dims[2];
        int w = x_dims[3];

        PADDLE_ENFORCE(y_dims_untrimmed[0] == n && y_dims_untrimmed[1] == c,
                       "Y should be in nc format");

        constexpr int simd_width = 16;
        int C = c / simd_width;

        const auto& multiply =
            math::jitkernel::KernelPool::Instance()
                .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);

#pragma omp parallel for collapse(2)
        for (int ni = 0; ni < n; ni++) {
          for (int ci = 0; ci < C; ci++) {
            auto ptr_x =
                x_data + ni * C * h * w * simd_width + ci * h * w * simd_width;

            auto ptr_y = y_data + ni * C * simd_width + ci * simd_width;
            auto ptr_z =
                z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;

            multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
          }
        }
      }

      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    } else {
      // Fallback to naive version:
      const bool are_inputs_in_same_format = x->format() == y->format();
      const bool is_x_nchw = x->format() == memory::format::nchw;
      const bool is_x_nc = x->format() == memory::format::nc;
      const bool is_y_nchw = y->format() == memory::format::nchw;
      const bool is_y_nc = y->format() == memory::format::nc;
      if (!are_inputs_in_same_format) {
        using platform::MKLDNNDeviceContext;
        auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
        const auto& mkldnn_engine = dev_ctx.GetEngine();
        if (!(is_x_nchw || is_x_nc))
          ReorderInput<T>((Tensor*)x, ctx.GetPlace(), mkldnn_engine,
                          x->dims().size() == 4);
        if (!(is_y_nchw || is_y_nc))
          ReorderInput<T>((Tensor*)y, ctx.GetPlace(), mkldnn_engine,
                          y->dims().size() == 4);
      }

      auto mul_func = [](T a, T b) -> T { return a * b; };

      TransformFunctor<decltype(mul_func), T,
                       paddle::platform::CPUDeviceContext, T>
          functor(
              x, y, z,
              ctx.template device_context<paddle::platform::CPUDeviceContext>(),
              mul_func);

      axis = (axis == -1 ? x_dims.size() - y_dims_untrimmed.size() : axis);
      PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
                     "Axis should be in range [0, x_dims)");

      auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      int pre, n, post;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

      if (post == 1) {
        functor.RunRowWise(n, pre);
      } else {
        functor.RunMidWise(n, pre, post);
      }
      z->set_layout(DataLayout::kMKLDNN);
      z->set_format(x->format());
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_KERNEL(elementwise_mul, MKLDNN, ::paddle::platform::CPUPlace,
                   ops::ElementwiseMulMKLDNNKernel<float>)
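For reference, the fast path above computes z = x .* broadcast(y), where x and z are laid out as nChw16c (channels split into blocks of 16 that form the innermost dimension) and y is a plain NC tensor. A scalar sketch of the same arithmetic, with a hypothetical helper name, may make the pointer math in the OpenMP loop easier to follow:

// Scalar equivalent of the JIT fast path (illustrative only).
// x, z: nChw16c, i.e. [n][C/16][h][w][16]; y: NC, i.e. [n][c].
static void eltwise_mul_nchw16c_nc_ref(const float* x, const float* y,
                                       float* z, int n, int c, int h, int w) {
  const int simd_width = 16;
  const int C = c / simd_width;
  for (int ni = 0; ni < n; ++ni)
    for (int ci = 0; ci < C; ++ci)
      for (int hw = 0; hw < h * w; ++hw)
        for (int i = 0; i < simd_width; ++i) {
          size_t off = ((size_t)(ni * C + ci) * h * w + hw) * simd_width + i;
          // Every 16-wide channel block of x is scaled by the matching
          // 16 consecutive channels of y for the same batch element.
          z[off] = x[off] * y[(size_t)ni * c + ci * simd_width + i];
        }
}

Each multiply->Compute(ptr_x, ptr_y, ptr_z, h, w) call above handles one (ni, ci) block, i.e. the two inner loops of this sketch, with y's 16 values held in a single ZMM register.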

paddle/fluid/operators/elementwise/elementwise_op.h

Lines changed: 14 additions & 0 deletions

@@ -97,6 +97,20 @@ class ElementwiseOpMaker : public framework::OpProtoAndCheckerMaker {
         .EqualGreaterThan(-1);
     AddAttr<bool>("use_mkldnn", "(bool, default false). Used by MKLDNN.")
         .SetDefault(false);
+    AddAttr<std::string>(
+        "x_data_format",
+        "(string, default \"\") Only used in mkldnn. "
+        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
+        "Specifies the data format of the input; the input will be "
+        "transformed automatically. ")
+        .SetDefault("");
+    AddAttr<std::string>(
+        "y_data_format",
+        "(string, default \"\") Only used in mkldnn. "
+        "An optional string from: \"NHWC\", \"NCHW\", \"NCHW16C\", \"NCHW8C\". "
+        "Specifies the data format of the input; the input will be "
+        "transformed automatically. ")
+        .SetDefault("");
     AddComment(string::Sprintf(R"DOC(
 Elementwise %s Operator

paddle/fluid/operators/math/jit_code.h

Lines changed: 36 additions & 0 deletions

@@ -322,6 +322,42 @@ class VActJitCode : public JitCode {
   ymm_t ymm_dst = ymm_t(1);
 };

+#ifdef PADDLE_WITH_MKLDNN
+struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
+  explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
+      : Xbyak::CodeGenerator(code_size) {
+    // RDI is ptr x_input
+    // RSI is ptr y_input
+    // RDX is ptr output
+    // RCX is height
+    // r8 is width
+
+    push(rbx);
+
+    xor_(rax, rax);
+    xor_(r10, r10);
+    vmovups(zmm3, ptr[rsi]);
+
+    L("h_loop");
+    xor_(rbx, rbx);
+    L("w_loop");
+    vmovups(zmm2, ptr[rdi + rax]);
+    vmulps(zmm1, zmm2, zmm3);
+    vmovups(ptr[rdx + rax], zmm1);
+    add(rax, 64);
+    inc(rbx);
+    cmp(r8, rbx);
+    jnz("w_loop");
+    inc(r10);
+    cmp(r10, rcx);
+    jnz("h_loop");
+
+    pop(rbx);
+    ret();
+  }
+};
+#endif
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
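The generator relies on the System V AMD64 calling convention, where RDI, RSI, RDX, RCX and R8 carry the first five integer and pointer arguments, so the emitted blob is directly callable through a function pointer of the matching signature. A hedged sketch of the invocation (mirroring what EltwiseMulnChw16cNCKernelImpl in jit_kernel_blas.cc below does):

// Code is emitted in the constructor; getCode() returns its address.
gen::EltwiseMulnChw16cNC jit;
using mul_func_t = void (*)(const float*, const float*, float*, int, int);
auto fn = (mul_func_t)jit.getCode();
// ptr_x/ptr_z: one nChw16c block (h*w vectors of 16 floats);
// ptr_y: the 16 channel values broadcast across that block.
fn(ptr_x, ptr_y, ptr_z, h, w);

Note that the broadcast operand is loaded once into ZMM3 before the loops, and RAX advances by 64 bytes per iteration, one 16-float ZMM vector at a time.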

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 9 additions & 0 deletions

@@ -95,6 +95,15 @@ class VAddBiasKernel : public Kernel {
   void (*Compute)(const T *, const T *, T *, int);
 };

+#ifdef PADDLE_WITH_MKLDNN
+template <typename T>
+class EltwiseMulnChw16cNCKernel : public Kernel {
+ public:
+  // nChw16c = nChw16c .* NC
+  void (*Compute)(const float *, const float *, float *, int, int);
+};
+#endif
+
 template <typename T>
 class VActKernel : public Kernel {
  public:
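As with the other JIT kernels in this header, Compute is a plain function pointer rather than a virtual method: the impl class writes the JIT-compiled entry point into it once, so the hot loop pays only an indirect call per block. Note the signature is hard-coded to float even though the class is templated on T.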

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 41 additions & 0 deletions

@@ -226,6 +226,44 @@ bool VAddKernelImpl<double>::useMKL(int d) {
 }
 #endif

+#ifdef PADDLE_WITH_MKLDNN
+/* EltwiseMul for nChw16c & NC inputs JitKernel */
+template <typename T>
+class EltwiseMulnChw16cNCKernelImpl
+    : public math::jitkernel::EltwiseMulnChw16cNCKernel<T> {
+ public:
+  JITKERNEL_DECLARE_STATIC_FUNC;
+  explicit EltwiseMulnChw16cNCKernelImpl(int d)
+      : EltwiseMulnChw16cNCKernel<T>() {
+    using mul_func_t = void (*)(const float*, const float*, float*, int, int);
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      // roughly estimate the size of code
+      size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
+      sz = sz > 4096 ? sz : 4096;
+      jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz));
+      this->Compute = (mul_func_t)jitcode_->getCode();
+      return;
+    }
+#endif
+    PADDLE_THROW(
+        "This kernel shouldn't be used in a non-Xbyak, non-MKL-DNN "
+        "environment");
+  }
+
+#ifdef PADDLE_WITH_XBYAK
+
+ private:
+  std::unique_ptr<gen::EltwiseMulnChw16cNC> jitcode_{nullptr};
+};
+
+template <>
+bool EltwiseMulnChw16cNCKernelImpl<float>::useJIT(int d) {
+  return true;
+}
+#endif
+#endif
+
 /* VAddRelu JitKernel */
 template <typename T>
 class VAddReluKernelImpl : public VAddReluKernel<T> {

@@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel);
 REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
 REGISTER_JITKERNEL(vrelu, VReluKernel);
 REGISTER_JITKERNEL(videntity, VIdentityKernel);
+#ifdef PADDLE_WITH_MKLDNN
+REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel);
+#endif

 }  // namespace jitkernel
 }  // namespace math
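After registration, callers obtain the kernel from the KernelPool cache, which compiles it on first use and keys it by kernel type plus a size hint. Condensed from the MKLDNN elementwise_mul kernel above:

// n (the batch size) serves as the pool's size key here.
const auto& multiply =
    math::jitkernel::KernelPool::Instance()
        .Get<math::jitkernel::EltwiseMulnChw16cNCKernel<float>>(n);
multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);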

python/paddle/fluid/tests/unittests/op_test.py

Lines changed: 3 additions & 1 deletion

@@ -362,7 +362,9 @@ def _get_places(self):
             else:
                 return []
         places = [fluid.CPUPlace()]
-        if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type):
+        cpu_only = self._cpu_only if hasattr(self, '_cpu_only') else False
+        if core.is_compiled_with_cuda() and core.op_support_gpu(self.op_type)\
+           and not cpu_only:
             places.append(core.CUDAPlace(0))
         return places
