Skip to content

Commit 1f00723

Browse files
committed
exp, sigmoid, tanh jitcode support more sizes
test=develop
1 parent 8cda7b3 commit 1f00723

File tree

7 files changed

+74
-72
lines changed

7 files changed

+74
-72
lines changed

paddle/fluid/operators/math/cpu_vec.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ namespace math {
3333
#define SIGMOID_THRESHOLD_MIN -40.0
3434
#define SIGMOID_THRESHOLD_MAX 13.0
3535

36-
#define AVX_FLOAT_BLOCK 8
36+
#define YMM_FLOAT_BLOCK 8
3737
#define AVX_DOUBLE_BLOCK 4
38-
#define AVX2_FLOAT_BLOCK 8
38+
#define YMM_FLOAT_BLOCK 8
3939
#define AVX2_DOUBLE_BLOCK 4
40-
#define AVX512_FLOAT_BLOCK 16
40+
#define ZMM_FLOAT_BLOCK 16
4141
#define AVX512_DOUBLE_BLOCK 8
4242

4343
template <typename T>
@@ -88,7 +88,7 @@ template <>
8888
inline void vec_scal<float, platform::jit::avx>(const int n, const float a,
8989
const float* x, float* y) {
9090
#ifdef __AVX__
91-
constexpr int block = AVX_FLOAT_BLOCK;
91+
constexpr int block = YMM_FLOAT_BLOCK;
9292
if (n < block) {
9393
vec_scal<float, platform::jit::isa_any>(n, a, x, y);
9494
return;
@@ -142,7 +142,7 @@ template <>
142142
inline void vec_bias_sub<float, platform::jit::avx>(const int n, const float a,
143143
const float* x, float* y) {
144144
#ifdef __AVX__
145-
constexpr int block = AVX_FLOAT_BLOCK;
145+
constexpr int block = YMM_FLOAT_BLOCK;
146146
if (n < block) {
147147
vec_bias_sub<float, platform::jit::isa_any>(n, a, x, y);
148148
return;
@@ -200,7 +200,7 @@ inline void vec_cross<float, platform::jit::avx>(const int n, const float* x,
200200
const float* y, const float* z,
201201
float* out) {
202202
#ifdef __AVX__
203-
constexpr int block = AVX_FLOAT_BLOCK;
203+
constexpr int block = YMM_FLOAT_BLOCK;
204204
if (n < block) {
205205
vec_cross<float, platform::jit::isa_any>(n, x, y, z, out);
206206
return;
@@ -257,7 +257,7 @@ template <>
257257
inline void vec_add_bias<float, platform::jit::avx>(const int n, const float a,
258258
const float* x, float* y) {
259259
#ifdef __AVX__
260-
constexpr int block = AVX_FLOAT_BLOCK;
260+
constexpr int block = YMM_FLOAT_BLOCK;
261261
if (n < block) {
262262
vec_add_bias<float, platform::jit::isa_any>(n, a, x, y);
263263
return;
@@ -326,7 +326,7 @@ template <>
326326
inline void vec_sigmoid<float, platform::jit::avx>(const int n, const float* x,
327327
float* y) {
328328
#ifdef __AVX__
329-
constexpr int block = AVX_FLOAT_BLOCK;
329+
constexpr int block = YMM_FLOAT_BLOCK;
330330
if (n < block) {
331331
vec_sigmoid<float, platform::jit::isa_any>(n, x, y);
332332
return;
@@ -415,7 +415,7 @@ template <>
415415
inline void vec_relu<float, platform::jit::avx>(const int n, const float* x,
416416
float* y) {
417417
#ifdef __AVX__
418-
constexpr int block = AVX_FLOAT_BLOCK;
418+
constexpr int block = YMM_FLOAT_BLOCK;
419419
if (n < block * 4) {
420420
vec_relu<float, platform::jit::isa_any>(n, x, y);
421421
return;

paddle/fluid/operators/math/jit_code.cc

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void VXXJitCode::generate() {
4141
} else if (scalar_index_ == 2) {
4242
vbroadcastss(ymm_src2, ptr[param2]);
4343
}
44-
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
44+
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
4545
if (scalar_index_ != 1) {
4646
vmovups(ymm_src1, ptr[param1 + offset]);
4747
}
@@ -57,9 +57,9 @@ void VXXJitCode::generate() {
5757
vmaxps(ymm_dst, ymm_zero, ymm_dst);
5858
}
5959
vmovups(ptr[param3 + offset], ymm_dst);
60-
offset += sizeof(float) * AVX_FLOAT_BLOCK;
60+
offset += sizeof(float) * YMM_FLOAT_BLOCK;
6161
}
62-
int rest = num_ % AVX_FLOAT_BLOCK;
62+
int rest = num_ % YMM_FLOAT_BLOCK;
6363
if (rest >= 4) {
6464
if (scalar_index_ != 1) {
6565
vmovups(xmm_src1, ptr[param1 + offset]);
@@ -133,23 +133,23 @@ void VXXJitCode::generate() {
133133

134134
#define REPEAT_8TIMES(val) val, val, val, val, val, val, val, val
135135

136-
#define OFFSET_EXP_ONE 0 * AVX_FLOAT_BLOCK * sizeof(float)
137-
#define OFFSET_EXP_TWO 1 * AVX_FLOAT_BLOCK * sizeof(float)
138-
#define OFFSET_EXP_0P5 2 * AVX_FLOAT_BLOCK * sizeof(float)
139-
#define OFFSET_EXP_HIG 3 * AVX_FLOAT_BLOCK * sizeof(float)
140-
#define OFFSET_EXP_LOW 4 * AVX_FLOAT_BLOCK * sizeof(float)
141-
#define OFFSET_EXP_LOG2EF 5 * AVX_FLOAT_BLOCK * sizeof(float)
142-
#define OFFSET_EXP_C1 6 * AVX_FLOAT_BLOCK * sizeof(float)
143-
#define OFFSET_EXP_C2 7 * AVX_FLOAT_BLOCK * sizeof(float)
144-
#define OFFSET_EXP_P0 8 * AVX_FLOAT_BLOCK * sizeof(float)
145-
#define OFFSET_EXP_P1 9 * AVX_FLOAT_BLOCK * sizeof(float)
146-
#define OFFSET_EXP_P2 10 * AVX_FLOAT_BLOCK * sizeof(float)
147-
#define OFFSET_EXP_P3 11 * AVX_FLOAT_BLOCK * sizeof(float)
148-
#define OFFSET_EXP_P4 12 * AVX_FLOAT_BLOCK * sizeof(float)
149-
#define OFFSET_EXP_P5 13 * AVX_FLOAT_BLOCK * sizeof(float)
150-
#define OFFSET_EXP_MAX_INPUT 14 * AVX_FLOAT_BLOCK * sizeof(float)
151-
#define OFFSET_SIGMOID_MAX 15 * AVX_FLOAT_BLOCK * sizeof(float)
152-
#define OFFSET_SIGMOID_MIN 16 * AVX_FLOAT_BLOCK * sizeof(float)
136+
#define OFFSET_EXP_ONE 0 * YMM_FLOAT_BLOCK * sizeof(float)
137+
#define OFFSET_EXP_TWO 1 * YMM_FLOAT_BLOCK * sizeof(float)
138+
#define OFFSET_EXP_0P5 2 * YMM_FLOAT_BLOCK * sizeof(float)
139+
#define OFFSET_EXP_HIG 3 * YMM_FLOAT_BLOCK * sizeof(float)
140+
#define OFFSET_EXP_LOW 4 * YMM_FLOAT_BLOCK * sizeof(float)
141+
#define OFFSET_EXP_LOG2EF 5 * YMM_FLOAT_BLOCK * sizeof(float)
142+
#define OFFSET_EXP_C1 6 * YMM_FLOAT_BLOCK * sizeof(float)
143+
#define OFFSET_EXP_C2 7 * YMM_FLOAT_BLOCK * sizeof(float)
144+
#define OFFSET_EXP_P0 8 * YMM_FLOAT_BLOCK * sizeof(float)
145+
#define OFFSET_EXP_P1 9 * YMM_FLOAT_BLOCK * sizeof(float)
146+
#define OFFSET_EXP_P2 10 * YMM_FLOAT_BLOCK * sizeof(float)
147+
#define OFFSET_EXP_P3 11 * YMM_FLOAT_BLOCK * sizeof(float)
148+
#define OFFSET_EXP_P4 12 * YMM_FLOAT_BLOCK * sizeof(float)
149+
#define OFFSET_EXP_P5 13 * YMM_FLOAT_BLOCK * sizeof(float)
150+
#define OFFSET_EXP_MAX_INPUT 14 * YMM_FLOAT_BLOCK * sizeof(float)
151+
#define OFFSET_SIGMOID_MAX 15 * YMM_FLOAT_BLOCK * sizeof(float)
152+
#define OFFSET_SIGMOID_MIN 16 * YMM_FLOAT_BLOCK * sizeof(float)
153153

154154
static const float exp_float_consts[] ALIGN32 = {
155155
REPEAT_8TIMES(1.f),
@@ -177,9 +177,12 @@ bool VActJitCode::init(int d, operand_type type) {
177177
bool ok = MayIUse(avx);
178178
if (type == operand_type::relu) {
179179
return ok;
180+
} else if (type == operand_type::exp) {
181+
// exp is slower than mkl when d >= 256
182+
return ok && d % 8 == 0 && d < 256;
180183
} else {
181184
// TODO(TJ): support more
182-
return ok && d == 8; // only 8 yet
185+
return ok && d % 8 == 0;
183186
}
184187
}
185188

@@ -224,7 +227,7 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
224227
vmovaps(ymm_tmp, ptr[reg_ptr_global + OFFSET_EXP_P0]);
225228
vmulps(ymm_dst, ymm_src, ymm_tmp);
226229
for (size_t i = OFFSET_EXP_P1; i < OFFSET_EXP_P5;
227-
i += (AVX_FLOAT_BLOCK * sizeof(float))) {
230+
i += (YMM_FLOAT_BLOCK * sizeof(float))) {
228231
vmovaps(ymm_tmp, ptr[reg_ptr_global + i]); // P1~P4
229232
vaddps(ymm_dst, ymm_dst, ymm_tmp);
230233
vmulps(ymm_dst, ymm_dst, ymm_src);
@@ -249,15 +252,15 @@ void VActJitCode::exp_ymm(ymm_t& ymm_dst, ymm_t& ymm_src, int fx_idx,
249252
reg64_t reg_ptr_tmp = reg_ptr_global;
250253
mov(reg_ptr_tmp, reinterpret_cast<size_t>(g_tmp_mem));
251254
vmovdqa(ptr[reg_ptr_tmp], ymm_int);
252-
vmovdqa(ptr[reg_ptr_tmp + AVX_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
255+
vmovdqa(ptr[reg_ptr_tmp + YMM_FLOAT_BLOCK * sizeof(float)], ymm_tmp);
253256
vpaddd(xtmp1, xtmp1, xtmp2);
254257
vpslld(xtmp1, xtmp1, 23);
255258
vmovdqa(ptr[reg_ptr_tmp], xtmp1);
256259
// next 128bits
257260
vmovdqa(xtmp1, ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)]);
258261
vmovdqa(xtmp2,
259262
ptr[reg_ptr_tmp +
260-
(AVX_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
263+
(YMM_FLOAT_BLOCK + 4 /*xmm float block*/) * sizeof(float)]);
261264
vpaddd(xtmp1, xtmp1, xtmp2);
262265
vpslld(xtmp1, xtmp1, 23);
263266
vmovdqa(ptr[reg_ptr_tmp + 4 /*xmm float block*/ * sizeof(float)], xtmp1);
@@ -317,7 +320,7 @@ void VActJitCode::generate() {
317320
vxorps(ymm_zero, ymm_zero, ymm_zero);
318321
}
319322
int offset = 0;
320-
for (int i = 0; i < num_ / AVX_FLOAT_BLOCK; ++i) {
323+
for (int i = 0; i < num_ / YMM_FLOAT_BLOCK; ++i) {
321324
vmovups(ymm_src, ptr[param1 + offset]);
322325
switch (type_) {
323326
case operand_type::relu:
@@ -338,14 +341,14 @@ void VActJitCode::generate() {
338341
break;
339342
}
340343
vmovups(ptr[param2 + offset], ymm_dst);
341-
offset += sizeof(float) * AVX_FLOAT_BLOCK;
344+
offset += sizeof(float) * YMM_FLOAT_BLOCK;
342345
}
343346
if (type_ != operand_type::relu) {
344347
// TODO(TJ): remove me
345348
ret();
346349
return;
347350
}
348-
int rest = num_ % AVX_FLOAT_BLOCK;
351+
int rest = num_ % YMM_FLOAT_BLOCK;
349352
if (rest >= 4) {
350353
vmovups(xmm_src, ptr[param1 + offset]);
351354
vmaxps(xmm_dst, xmm_zero, xmm_src);

paddle/fluid/operators/math/jit_kernel.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,9 @@ namespace jitkernel {
2929
#define SIGMOID_THRESHOLD_MIN -40.0
3030
#define SIGMOID_THRESHOLD_MAX 13.0
3131
#define EXP_MAX_INPUT 40.0
32-
// TODO(TJ): change AVX_FLOAT_BLOCK to YMM_FLOAT_BLOCK
33-
#define AVX_FLOAT_BLOCK 8
34-
#define AVX2_FLOAT_BLOCK 8
35-
#define AVX512_FLOAT_BLOCK 16
32+
#define XMM_FLOAT_BLOCK 4
33+
#define YMM_FLOAT_BLOCK 8
34+
#define ZMM_FLOAT_BLOCK 16
3635

3736
typedef enum { kLT8, kEQ8, kGT8LT16, kEQ16, kGT16 } jit_block;
3837

paddle/fluid/operators/math/jit_kernel_blas.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class VMulKernelImpl : public VMulKernel<T> {
133133
#ifdef PADDLE_WITH_XBYAK
134134
if (useJIT(d)) {
135135
// roughly estimate the size of code
136-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
136+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
137137
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 0, false,
138138
sz > 4096 ? sz : 4096));
139139
this->Compute =
@@ -184,7 +184,7 @@ class VAddKernelImpl : public VAddKernel<T> {
184184
explicit VAddKernelImpl(int d) : VAddKernel<T>() {
185185
#ifdef PADDLE_WITH_XBYAK
186186
if (useJIT(d)) {
187-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
187+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
188188
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, false,
189189
sz > 4096 ? sz : 4096));
190190
this->Compute =
@@ -234,7 +234,7 @@ class VAddReluKernelImpl : public VAddReluKernel<T> {
234234
explicit VAddReluKernelImpl(int d) : VAddReluKernel<T>() {
235235
#ifdef PADDLE_WITH_XBYAK
236236
if (useJIT(d)) {
237-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
237+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
238238
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 0, true,
239239
sz > 4096 ? sz : 4096));
240240
this->Compute =
@@ -266,7 +266,7 @@ class VScalKernelImpl : public VScalKernel<T> {
266266
explicit VScalKernelImpl(int d) : VScalKernel<T>() {
267267
#ifdef PADDLE_WITH_XBYAK
268268
if (useJIT(d)) {
269-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
269+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
270270
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::mul, 1, false,
271271
sz > 4096 ? sz : 4096));
272272
this->Compute =
@@ -315,7 +315,7 @@ class VAddBiasKernelImpl : public VAddBiasKernel<T> {
315315
explicit VAddBiasKernelImpl(int d) : VAddBiasKernel<T>() {
316316
#ifdef PADDLE_WITH_XBYAK
317317
if (useJIT(d)) {
318-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8;
318+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 4 * 8;
319319
jitcode_.reset(new gen::VXXJitCode(d, gen::operand_type::add, 1, false,
320320
sz > 4096 ? sz : 4096));
321321
this->Compute =
@@ -349,7 +349,7 @@ class VReluKernelImpl : public VReluKernel<T> {
349349
#ifdef PADDLE_WITH_XBYAK
350350
if (useJIT(d)) {
351351
size_t sz = 96 /* init size */ +
352-
d / AVX_FLOAT_BLOCK * 4 /* instructions */ *
352+
d / YMM_FLOAT_BLOCK * 4 /* instructions */ *
353353
8 /* average bytes for each instruction */;
354354
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::relu,
355355
sz > 4096 ? sz : 4096));

paddle/fluid/operators/math/jit_kernel_crf_decode.cc

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
105105
int tag_num) \
106106
: CRFDecodeKernel<float>() { \
107107
this->num_ = tag_num; \
108-
this->end_ = this->num_ / AVX_FLOAT_BLOCK; \
109-
this->rest_ = this->num_ % AVX_FLOAT_BLOCK; \
108+
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
109+
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
110110
} \
111111
template <> \
112112
void CRFDecodeKernelImpl<float, jit::avx, block>::Compute( \
113113
const int seq_len, const float* x, const float* w, float* alpha, \
114114
int* track) const { \
115-
INIT_ALPHA(AVX_FLOAT_BLOCK) \
115+
INIT_ALPHA(YMM_FLOAT_BLOCK) \
116116
/* Use the column-major strategy to get the location of maximum score.*/ \
117117
int seq_offset = 0; \
118118
constexpr int state_trans_base_idx = 2; \
@@ -150,7 +150,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
150150
max_score = _mm256_max_ps(max_score, score_v); \
151151
trans_offset += this->num_; \
152152
} \
153-
UPDATE_ALPHA(AVX_FLOAT_BLOCK) \
153+
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
154154
} \
155155
seq_offset += this->num_; \
156156
} \
@@ -161,14 +161,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
161161
CRFDecodeKernelImpl<float, isa, block>::CRFDecodeKernelImpl(int tag_num) \
162162
: CRFDecodeKernel<float>() { \
163163
this->num_ = tag_num; \
164-
this->end_ = this->num_ / AVX2_FLOAT_BLOCK; \
165-
this->rest_ = this->num_ % AVX2_FLOAT_BLOCK; \
164+
this->end_ = this->num_ / YMM_FLOAT_BLOCK; \
165+
this->rest_ = this->num_ % YMM_FLOAT_BLOCK; \
166166
} \
167167
template <> \
168168
void CRFDecodeKernelImpl<float, isa, block>::Compute( \
169169
const int seq_len, const float* x, const float* w, float* alpha, \
170170
int* track) const { \
171-
INIT_ALPHA(AVX2_FLOAT_BLOCK) \
171+
INIT_ALPHA(YMM_FLOAT_BLOCK) \
172172
/* Use the column-major strategy to get the location of maximum score.*/ \
173173
int seq_offset = 0; \
174174
constexpr int state_trans_base_idx = 2; \
@@ -196,7 +196,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
196196
max_score = _mm256_max_ps(max_score, score_v); \
197197
trans_offset += this->num_; \
198198
} \
199-
UPDATE_ALPHA(AVX2_FLOAT_BLOCK) \
199+
UPDATE_ALPHA(YMM_FLOAT_BLOCK) \
200200
} \
201201
seq_offset += this->num_; \
202202
} \
@@ -208,14 +208,14 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
208208
int tag_num) \
209209
: CRFDecodeKernel<float>() { \
210210
this->num_ = tag_num; \
211-
this->end_ = this->num_ / AVX512_FLOAT_BLOCK; \
212-
this->rest_ = this->num_ % AVX512_FLOAT_BLOCK; \
211+
this->end_ = this->num_ / ZMM_FLOAT_BLOCK; \
212+
this->rest_ = this->num_ % ZMM_FLOAT_BLOCK; \
213213
} \
214214
template <> \
215215
void CRFDecodeKernelImpl<float, jit::avx512f, block>::Compute( \
216216
const int seq_len, const float* x, const float* w, float* alpha, \
217217
int* track) const { \
218-
INIT_ALPHA(AVX512_FLOAT_BLOCK) \
218+
INIT_ALPHA(ZMM_FLOAT_BLOCK) \
219219
/* Use the column-major strategy to get the location of maximum score.*/ \
220220
int seq_offset = 0; \
221221
constexpr int state_trans_base_idx = 2; \
@@ -250,7 +250,7 @@ class CRFDecodeKernelImpl : public CRFDecodeKernel<T> {
250250
this->num_ + j_offset), \
251251
max_j); \
252252
/* Calculate the offset of next step*/ \
253-
j_offset += AVX512_FLOAT_BLOCK; \
253+
j_offset += ZMM_FLOAT_BLOCK; \
254254
if (j == this->end_ - 1) { \
255255
if (this->rest_ > 0) { \
256256
j_offset += last_offset; \

paddle/fluid/operators/math/jit_kernel_exp.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class VExpKernelImpl : public VExpKernel<T> {
116116
explicit VExpKernelImpl(int d) : VExpKernel<T>() {
117117
#ifdef PADDLE_WITH_XBYAK
118118
if (useJIT(d)) {
119-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
119+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 70 * 8;
120120
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::exp,
121121
sz > 4096 ? sz : 4096));
122122
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
@@ -167,7 +167,7 @@ class VSigmoidKernelImpl : public VSigmoidKernel<T> {
167167
explicit VSigmoidKernelImpl(int d) : VSigmoidKernel<T>() {
168168
#ifdef PADDLE_WITH_XBYAK
169169
if (useJIT(d)) {
170-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
170+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 82 * 8;
171171
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::sigmoid,
172172
sz > 4096 ? sz : 4096));
173173
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();
@@ -219,7 +219,7 @@ class VTanhKernelImpl : public VTanhKernel<T> {
219219
explicit VTanhKernelImpl(int d) : VTanhKernel<T>() {
220220
#ifdef PADDLE_WITH_XBYAK
221221
if (useJIT(d)) {
222-
size_t sz = 96 + d / AVX_FLOAT_BLOCK * 4 * 8; // should change
222+
size_t sz = 96 + d / YMM_FLOAT_BLOCK * 84 * 8;
223223
jitcode_.reset(new gen::VActJitCode(d, gen::operand_type::tanh,
224224
sz > 4096 ? sz : 4096));
225225
this->Compute = jitcode_->getCode<void (*)(const T*, T*, int)>();

0 commit comments

Comments
 (0)