Commit c69c416

MKLDNN elementwise_mul: Move Kernel to KernelPool to avoid segfaults
test=develop
1 parent 99e3e36 commit c69c416

4 files changed: +95 −52 lines

paddle/fluid/operators/elementwise_mul_mkldnn_op.cc renamed to paddle/fluid/operators/elementwise/elementwise_mul_mkldnn_op.cc

Lines changed: 9 additions & 52 deletions
@@ -13,61 +13,21 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <mkldnn/include/mkldnn.hpp>
-#include "paddle/fluid/operators/elementwise_op.h"
-#include "paddle/fluid/operators/elementwise_op_function.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
 
 #include "paddle/fluid/platform/mkldnn_helper.h"
 
-#include "xbyak/xbyak.h"
-#include "xbyak/xbyak_util.h"
+#include "paddle/fluid/operators/math/jit_kernel.h"
+#include "xbyak.h"
+#include "xbyak_util.h"
 
 namespace paddle {
 namespace operators {
 
 using framework::DataLayout;
 using mkldnn::memory;
 
-struct vector_mul : public Xbyak::CodeGenerator {
-  vector_mul() {
-    // RDI is ptr X
-    // RSI is ptr Y
-    // RDX is ptr Z
-    // RCX is h
-    // r8 is w
-
-    push(rbx);
-
-    xor_(rax, rax);
-    xor_(r10, r10);
-    vmovups(zmm3, ptr[rsi]);
-
-    L("h_loop");
-    xor_(rbx, rbx);
-    L("w_loop");
-    vmovups(zmm2, ptr[rdi + rax]);
-    vmulps(zmm1, zmm2, zmm3);
-    vmovups(ptr[rdx + rax], zmm1);
-    add(rax, 64);
-    inc(rbx);
-    cmp(r8, rbx);
-    jnz("w_loop");
-    inc(r10);
-    cmp(r10, rcx);
-    jnz("h_loop");
-
-    pop(rbx);
-    ret();
-  }
-};
-
-void check(const float* x, const float* y, float* z, int w) {
-  for (int wi = 0; wi < w; wi++) {
-    for (int i = 0; i < 16; i++) {
-      z[wi * 16 + i] = x[wi * 16 + i] * y[i];
-    }
-  }
-}
-
 static mkldnn::memory::format StringToMKLDNNFormat(std::string& format) {
   std::transform(format.begin(), format.end(), format.begin(), ::tolower);
 
@@ -163,12 +123,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
         constexpr int simd_width = 16;
         int C = c / simd_width;
 
-        vector_mul mul;
-
-        using mul_func_t =
-            void (*)(const float*, const float*, float*, int, int);
-
-        mul_func_t mul_func = (mul_func_t)mul.getCode();
+        const auto& multiply =
+            math::jitkernel::KernelPool::Instance()
+                .template Get<math::jitkernel::EltwiseMulnChw16cNCKernel<T>>(n);
 
 #pragma omp parallel for collapse(2)
         for (int ni = 0; ni < n; ni++) {
@@ -180,7 +137,7 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
           auto ptr_z =
               z_data + ni * C * h * w * simd_width + ci * h * w * simd_width;
 
-          mul_func(ptr_x, ptr_y, ptr_z, h, w);
+          multiply->Compute(ptr_x, ptr_y, ptr_z, h, w);
         }
       }
     }
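
Context for the change above: the old code built a vector_mul code generator on the stack inside every Compute call and invoked the raw pointer returned by getCode(), while the new code fetches the kernel from the singleton KernelPool, which constructs it once and keeps it alive. A minimal sketch of that caching pattern, with a hypothetical GetCachedJitCode helper standing in for the real KernelPool:

#include <memory>
#include <mutex>
#include <unordered_map>

// Hypothetical stand-in for KernelPool: generate the JIT code at most once
// per key and keep the generator (which owns the executable buffer) alive
// for the lifetime of the process, so callers never execute freed code.
template <typename Gen>
const Gen& GetCachedJitCode(int key) {
  static std::unordered_map<int, std::unique_ptr<Gen>> pool;
  static std::mutex mtx;
  std::lock_guard<std::mutex> lock(mtx);
  auto& entry = pool[key];
  if (!entry) entry.reset(new Gen());  // emitted exactly once per key
  return *entry;
}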

paddle/fluid/operators/math/jit_code.h

Lines changed: 36 additions & 0 deletions
@@ -156,6 +156,42 @@ class VActJitCode : public JitCode {
   ymm_t ymm_dst = ymm_t(1);
 };
 
+#ifdef PADDLE_WITH_MKLDNN
+struct EltwiseMulnChw16cNC : public Xbyak::CodeGenerator {
+  explicit EltwiseMulnChw16cNC(size_t code_size = 256 * 1024)
+      : Xbyak::CodeGenerator(code_size) {
+    // RDI is ptr x_input
+    // RSI is ptr y_input
+    // RDX is ptr output
+    // RCX is height
+    // r8 is width
+
+    push(rbx);
+
+    xor_(rax, rax);
+    xor_(r10, r10);
+    vmovups(zmm3, ptr[rsi]);
+
+    L("h_loop");
+    xor_(rbx, rbx);
+    L("w_loop");
+    vmovups(zmm2, ptr[rdi + rax]);
+    vmulps(zmm1, zmm2, zmm3);
+    vmovups(ptr[rdx + rax], zmm1);
+    add(rax, 64);
+    inc(rbx);
+    cmp(r8, rbx);
+    jnz("w_loop");
+    inc(r10);
+    cmp(r10, rcx);
+    jnz("h_loop");
+
+    pop(rbx);
+    ret();
+  }
+};
+#endif
+
 }  // namespace gen
 }  // namespace jitkernel
 }  // namespace math
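
For reference, the generated loop above loads 16 floats of y into zmm3 once per call and multiplies every 64-byte (16-float, one ZMM register) block of x by it, which is why the pointer advances by add(rax, 64). This matches the scalar check helper deleted from the op file; a plain C++ equivalent of what the emitted code computes:

// Scalar reference for the JIT-generated routine (adapted from the removed
// `check` helper): one 16-float vector of y is reused across all h * w
// 16-channel blocks of x.
void eltwise_mul_nchw16c_ref(const float* x, const float* y, float* z,
                             int h, int w) {
  for (int i = 0; i < h * w; ++i) {
    for (int k = 0; k < 16; ++k) {
      z[i * 16 + k] = x[i * 16 + k] * y[k];
    }
  }
}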

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 9 additions & 0 deletions
@@ -94,6 +94,15 @@ class VAddBiasKernel : public Kernel {
   void (*Compute)(const T *, const T *, T *, int);
 };
 
+#ifdef PADDLE_WITH_MKLDNN
+template <typename T>
+class EltwiseMulnChw16cNCKernel : public Kernel {
+ public:
+  // nChw16c = nChw16c .* NC
+  void (*Compute)(const float *, const float *, float *, int, int);
+};
+#endif
+
 template <typename T>
 class VActKernel : public Kernel {
  public:
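
The comment "nChw16c = nChw16c .* NC" describes the layouts involved: x and z use MKL-DNN's blocked nChw16c format, in which channels are grouped into blocks of 16 and the 16 values of a block are contiguous and innermost, while y is a plain NC tensor. An illustrative index helper (not part of this commit) for the blocked layout:

// Illustrative only: flat offset of logical element (n, c, h, w) in an
// nChw16c tensor. The 16 channels of a block sit next to each other, which
// is what lets the JIT kernel process one block per ZMM register.
inline size_t nchw16c_offset(size_t n, size_t c, size_t h, size_t w,
                             size_t C, size_t H, size_t W) {
  const size_t cb = c / 16;  // channel-block index
  const size_t ci = c % 16;  // position inside the block
  return (((n * (C / 16) + cb) * H + h) * W + w) * 16 + ci;
}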

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 41 additions & 0 deletions
@@ -226,6 +226,44 @@ bool VAddKernelImpl<double>::useMKL(int d) {
 }
 #endif
 
+#ifdef PADDLE_WITH_MKLDNN
+/* EltwiseMul for nChw16c & NC inputs JitKernel */
+template <typename T>
+class EltwiseMulnChw16cNCKernelImpl
+    : public math::jitkernel::EltwiseMulnChw16cNCKernel<T> {
+ public:
+  JITKERNEL_DECLARE_STATIC_FUNC;
+  explicit EltwiseMulnChw16cNCKernelImpl(int d)
+      : EltwiseMulnChw16cNCKernel<T>() {
+    using mul_func_t = void (*)(const float*, const float*, float*, int, int);
+#ifdef PADDLE_WITH_XBYAK
+    if (useJIT(d)) {
+      // roughly estimate the size of code
+      size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
+      sz = sz > 4096 ? sz : 4096;
+      jitcode_.reset(new gen::EltwiseMulnChw16cNC(sz));
+      this->Compute = (mul_func_t)jitcode_->getCode();
+      return;
+    }
+#endif
+    PADDLE_THROW(
+        "This kernel shouldn't be used in Non-Xbyak, Non-MKL-DNN "
+        "environment");
+  }
+
+#ifdef PADDLE_WITH_XBYAK
+
+ private:
+  std::unique_ptr<gen::EltwiseMulnChw16cNC> jitcode_{nullptr};
+};
+
+template <>
+bool EltwiseMulnChw16cNCKernelImpl<float>::useJIT(int d) {
+  return true;
+}
+#endif
+#endif
+
 /* VAddRelu JitKernel */
 template <typename T>
 class VAddReluKernelImpl : public VAddReluKernel<T> {
@@ -394,6 +432,9 @@ REGISTER_JITKERNEL(vscal, VScalKernel);
 REGISTER_JITKERNEL(vaddbias, VAddBiasKernel);
 REGISTER_JITKERNEL(vrelu, VReluKernel);
 REGISTER_JITKERNEL(videntity, VIdentityKernel);
+#ifdef PADDLE_WITH_MKLDNN
+REGISTER_JITKERNEL(eltwise_mul_nchw16c, EltwiseMulnChw16cNCKernel);
+#endif
 
 }  // namespace jitkernel
 }  // namespace math
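
With the kernel registered, callers resolve it through the pool the same way the rewritten elementwise_mul op does. Roughly, on the consumer side (the namespace alias and wrapper function are illustrative, not from the commit):

// Consumer-side sketch; the real call site is in
// elementwise_mul_mkldnn_op.cc above.
namespace jit = paddle::operators::math::jitkernel;

void MulOneSlice(const float* x, const float* y, float* z,
                 int n, int h, int w) {
  const auto& kernel =
      jit::KernelPool::Instance().Get<jit::EltwiseMulnChw16cNCKernel<float>>(n);
  kernel->Compute(x, y, z, h, w);  // JIT multiply over one nChw16c slice
}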
