Skip to content

Commit e6cfdf6

Browse files
authored
Merge pull request #14274 from tensor-tang/fix/jit
fix jit on mac
2 parents 8ac2242 + b81e1b6 commit e6cfdf6

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

paddle/fluid/operators/math/CMakeLists.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,12 @@ if(WITH_GPU)
7575
endif()
7676
cc_test(concat_test SRCS concat_test.cc DEPS concat_and_split)
7777
cc_test(cpu_vec_test SRCS cpu_vec_test.cc DEPS blas cpu_info)
78-
cc_library(jit_kernel
79-
SRCS jit_kernel.cc jit_gen.cc jit_code.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc
80-
DEPS cpu_info cblas gflags enforce)
78+
79+
set(JIT_KERNEL_SRCS jit_kernel.cc jit_kernel_blas.cc jit_kernel_exp.cc jit_kernel_rnn.cc jit_kernel_crf_decode.cc)
80+
set(JIT_KERNEL_DEPS cpu_info cblas gflags enforce)
81+
if(WITH_XBYAK)
82+
list(APPEND JIT_KERNEL_SRCS jit_gen.cc jit_code.cc)
83+
list(APPEND JIT_KERNEL_DEPS xbyak)
84+
endif()
85+
cc_library(jit_kernel SRCS ${JIT_KERNEL_SRCS} DEPS ${JIT_KERNEL_DEPS})
8186
cc_test(jit_kernel_test SRCS jit_kernel_test.cc DEPS jit_kernel)

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,13 @@ limitations under the License. */
1414

1515
#include "paddle/fluid/operators/math/jit_kernel.h"
1616
#include <string>
17-
#include "paddle/fluid/operators/math/jit_code.h"
1817
#include "paddle/fluid/operators/math/jit_kernel_macro.h"
1918
#include "paddle/fluid/platform/enforce.h"
2019

20+
#ifdef PADDLE_WITH_XBYAK
21+
#include "paddle/fluid/operators/math/jit_code.h"
22+
#endif
23+
2124
#ifdef PADDLE_WITH_MKLML
2225
#include "paddle/fluid/platform/dynload/mklml.h"
2326
#endif
@@ -64,6 +67,7 @@ class VMulKernelImpl : public VMulKernel<T> {
6467
static inline bool useMKL(int d) { return false; }
6568

6669
explicit VMulKernelImpl(int d) : VMulKernel<T>() {
70+
#ifdef PADDLE_WITH_XBYAK
6771
if (useJIT(d)) {
6872
// roughly estimate the size of code
6973
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
@@ -72,6 +76,7 @@ class VMulKernelImpl : public VMulKernel<T> {
7276
jitcode_->getCode<void (*)(const T*, const T*, T*, int)>();
7377
return;
7478
}
79+
#endif
7580
#ifdef PADDLE_WITH_MKLML
7681
if (useMKL(d)) {
7782
this->Compute = VMulMKL<T>;
@@ -81,15 +86,21 @@ class VMulKernelImpl : public VMulKernel<T> {
8186
this->Compute = VMulRefer<T>;
8287
}
8388

89+
#ifdef PADDLE_WITH_XBYAK
90+
8491
private:
8592
std::unique_ptr<gen::VMulJitCode> jitcode_{nullptr};
93+
#endif
8694
};
8795

96+
#ifdef PADDLE_WITH_XBYAK
8897
template <>
8998
bool VMulKernelImpl<float>::useJIT(int d) {
9099
return gen::VMulJitCode::init(d);
91100
}
101+
#endif
92102

103+
#ifdef PADDLE_WITH_MKLML
93104
template <>
94105
bool VMulKernelImpl<float>::useMKL(int d) {
95106
return jit::MayIUse(jit::avx512f) && d > 512;
@@ -99,6 +110,7 @@ template <>
99110
bool VMulKernelImpl<double>::useMKL(int d) {
100111
return true;
101112
}
113+
#endif
102114

103115
REGISTER_JITKERNEL(vmul, VMulKernel);
104116

0 commit comments

Comments
 (0)