Skip to content

Commit a19b322

Browse files
committed
fix jitcode small size
test=develop
1 parent: 4dbdfa6; commit: a19b322

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -59,23 +59,26 @@ void VXXJitCode::generate() {
5959
offset += sizeof(float) * YMM_FLOAT_BLOCK;
6060
}
6161
int rest = num_ % YMM_FLOAT_BLOCK;
62-
int block = XMM_FLOAT_BLOCK;
6362
while (rest > 0) {
63+
int block = XMM_FLOAT_BLOCK;
6464
if (rest >= 4) {
65+
block = 4;
6566
if (scalar_index_ != 1) {
6667
vmovups(xmm_src1, ptr[param1 + offset]);
6768
}
6869
if (scalar_index_ != 2) {
6970
vmovups(xmm_src2, ptr[param2 + offset]);
7071
}
7172
} else if (rest >= 2) {
73+
block = 2;
7274
if (scalar_index_ != 1) {
7375
vmovq(xmm_src1, ptr[param1 + offset]);
7476
}
7577
if (scalar_index_ != 2) {
7678
vmovq(xmm_src2, ptr[param2 + offset]);
7779
}
7880
} else {
81+
block = 1;
7982
if (scalar_index_ != 1) {
8083
vmovss(xmm_src1, ptr[param1 + offset]);
8184
}
@@ -105,7 +108,6 @@ void VXXJitCode::generate() {
105108
}
106109
offset += sizeof(float) * block;
107110
rest -= block;
108-
block /= 2;
109111
}
110112
ret();
111113
}
@@ -167,13 +169,16 @@ void VActJitCode::generate() {
167169
offset += sizeof(float) * YMM_FLOAT_BLOCK;
168170
}
169171
int rest = num_ % YMM_FLOAT_BLOCK;
170-
int block = XMM_FLOAT_BLOCK;
171172
while (rest > 0) {
173+
int block = XMM_FLOAT_BLOCK;
172174
if (rest >= 4) {
175+
block = 4;
173176
vmovups(xmm_src, ptr[param1 + offset]);
174177
} else if (rest >= 2) {
178+
block = 2;
175179
vmovq(xmm_src, ptr[param1 + offset]);
176180
} else {
181+
block = 1;
177182
vmovss(xmm_src, ptr[param1 + offset]);
178183
}
179184
switch (type_) {
@@ -201,7 +206,6 @@ void VActJitCode::generate() {
201206
}
202207
offset += sizeof(float) * block;
203208
rest -= block;
204-
block /= 2;
205209
}
206210
ret();
207211
}

paddle/fluid/operators/math/jit_kernel_test.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
6969

7070
TEST(JitKernel, vrelu) {
7171
namespace jit = paddle::operators::math::jitkernel;
72-
for (int d : {7, 8, 15, 16, 30, 256, 512}) {
72+
for (int d : {3, 7, 8, 15, 16, 30, 256, 512}) {
7373
std::vector<float> x(d);
7474
std::vector<float> zref(d), ztgt(d);
7575
RandomVec<float>(d, x.data(), -10.f, 1.f);
@@ -159,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
159159

160160
TEST(JitKernel, vexp) {
161161
namespace jit = paddle::operators::math::jitkernel;
162-
for (int d : {7, 8, 12, 15, 16, 20, 30, 128, 256}) {
162+
for (int d : {1, 3, 4, 6, 7, 8, 12, 15, 16, 20, 30, 128, 256}) {
163163
std::vector<float> x(d);
164164
std::vector<float> zref(d), ztgt(d);
165165
RandomVec<float>(d, x.data(), -2.f, 2.f);
@@ -234,7 +234,7 @@ void vsigmoid_better(
234234

235235
TEST(JitKernel, vsigmoid) {
236236
namespace jit = paddle::operators::math::jitkernel;
237-
for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
237+
for (int d : {1, 3, 4, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
238238
std::vector<float> x(d);
239239
std::vector<float> zref(d), ztgt(d);
240240
RandomVec<float>(d, x.data(), -2.f, 2.f);
@@ -298,7 +298,7 @@ void vtanh_better(
298298

299299
TEST(JitKernel, vtanh) {
300300
namespace jit = paddle::operators::math::jitkernel;
301-
for (int d : {7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
301+
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100, 128, 256}) {
302302
std::vector<float> x(d);
303303
std::vector<float> zref(d), ztgt(d);
304304
RandomVec<float>(d, x.data(), -2.f, 2.f);
@@ -389,7 +389,7 @@ void lstm_ctht_better(
389389

390390
TEST(JitKernel, lstm) {
391391
namespace jit = paddle::operators::math::jitkernel;
392-
for (int d : {7, 8, 15, 16, 30, 32, 64, 100}) {
392+
for (int d : {1, 2, 3, 4, 5, 6, 7, 8, 15, 16, 30, 32, 64, 100}) {
393393
int d4 = d * 4;
394394
int d3 = d * 3;
395395
std::vector<float> x(d4), xref(d4);

0 commit comments

Comments (0)