Skip to content

Commit 4e67fe6

Browse files
committed
refine act and vxx with all size
1 parent ba3eaed commit 4e67fe6

File tree

1 file changed

+60
-87
lines changed

1 file changed

+60
-87
lines changed

paddle/fluid/operators/math/jit_code.cc

Lines changed: 60 additions & 87 deletions
Original file line number | Diff line number | Diff line change
@@ -60,60 +60,53 @@ void VXXJitCode::generate() {
6060
offset += sizeof(float) * YMM_FLOAT_BLOCK;
6161
}
6262
int rest = num_ % YMM_FLOAT_BLOCK;
63-
if (rest >= 4) {
64-
if (scalar_index_ != 1) {
65-
vmovups(xmm_src1, ptr[param1 + offset]);
66-
}
67-
if (scalar_index_ != 2) {
68-
vmovups(xmm_src2, ptr[param2 + offset]);
69-
}
70-
if (type_ == operand_type::mul) {
71-
vmulps(xmm_dst, xmm_src1, xmm_src2);
72-
} else if (type_ == operand_type::add) {
73-
vaddps(xmm_dst, xmm_src1, xmm_src2);
74-
}
75-
if (with_relu_) {
76-
vmaxps(xmm_dst, xmm_zero, xmm_dst);
77-
}
78-
vmovups(ptr[param3 + offset], xmm_dst);
79-
offset += sizeof(float) * 4;
80-
rest -= 4;
81-
}
82-
if (rest >= 2) {
83-
if (scalar_index_ != 1) {
84-
vmovq(xmm_src1, ptr[param1 + offset]);
85-
}
86-
if (scalar_index_ != 2) {
87-
vmovq(xmm_src2, ptr[param2 + offset]);
63+
int block = XMM_FLOAT_BLOCK;
64+
while (rest > 0) {
65+
if (rest >= 4) {
66+
if (scalar_index_ != 1) {
67+
vmovups(xmm_src1, ptr[param1 + offset]);
68+
}
69+
if (scalar_index_ != 2) {
70+
vmovups(xmm_src2, ptr[param2 + offset]);
71+
}
72+
} else if (rest >= 2) {
73+
if (scalar_index_ != 1) {
74+
vmovq(xmm_src1, ptr[param1 + offset]);
75+
}
76+
if (scalar_index_ != 2) {
77+
vmovq(xmm_src2, ptr[param2 + offset]);
78+
}
79+
} else {
80+
if (scalar_index_ != 1) {
81+
vmovss(xmm_src1, ptr[param1 + offset]);
82+
}
83+
if (scalar_index_ != 2) {
84+
vmovss(xmm_src2, ptr[param2 + offset]);
85+
}
8886
}
89-
if (type_ == operand_type::mul) {
90-
vmulps(xmm_dst, xmm_src1, xmm_src2);
91-
} else if (type_ == operand_type::add) {
92-
vaddps(xmm_dst, xmm_src1, xmm_src2);
87+
switch (type_) {
88+
case operand_type::mul:
89+
vmulps(xmm_dst, xmm_src1, xmm_src2);
90+
break;
91+
case operand_type::add:
92+
vaddps(xmm_dst, xmm_src1, xmm_src2);
93+
break;
94+
default:
95+
break;
9396
}
9497
if (with_relu_) {
9598
vmaxps(xmm_dst, xmm_zero, xmm_dst);
9699
}
97-
vmovq(ptr[param3 + offset], xmm_dst);
98-
offset += sizeof(float) * 2;
99-
rest -= 2;
100-
}
101-
if (rest > 0) {
102-
if (scalar_index_ != 1) {
103-
vmovss(xmm_src1, ptr[param1 + offset]);
104-
}
105-
if (scalar_index_ != 2) {
106-
vmovss(xmm_src2, ptr[param2 + offset]);
107-
}
108-
if (type_ == operand_type::mul) {
109-
vmulss(xmm_dst, xmm_src1, xmm_src2);
110-
} else if (type_ == operand_type::add) {
111-
vaddss(xmm_dst, xmm_src1, xmm_src2);
100+
if (rest >= 4) {
101+
vmovups(ptr[param3 + offset], xmm_dst);
102+
} else if (rest >= 2) {
103+
vmovq(ptr[param3 + offset], xmm_dst);
104+
} else {
105+
vmovss(ptr[param3 + offset], xmm_dst);
112106
}
113-
if (with_relu_) {
114-
vmaxps(xmm_dst, xmm_zero, xmm_dst);
115-
}
116-
vmovss(ptr[param3 + offset], xmm_dst);
107+
offset += sizeof(float) * block;
108+
rest -= block;
109+
block /= 2;
117110
}
118111
ret();
119112
}
@@ -175,11 +168,9 @@ static int g_tmp_mem[16] ALIGN32 = {0};
175168

176169
bool VActJitCode::init(int d, operand_type type) {
177170
bool ok = MayIUse(avx);
178-
if (type == operand_type::relu) {
171+
if (type == operand_type::relu || type == operand_type::exp) {
172+
// TODO(TJ): implement avx512, avx_exp is slower than mkl when d >= 256
179173
return ok;
180-
} else if (type == operand_type::exp) {
181-
// exp is slower than mkl when d >= 256
182-
return ok; //&& d % 4 == 0 && d < 256;
183174
} else {
184175
// TODO(TJ): support more
185176
return ok && d % 8 == 0;
@@ -412,24 +403,15 @@ void VActJitCode::generate() {
412403
return;
413404
}
414405
int rest = num_ % YMM_FLOAT_BLOCK;
415-
if (rest >= 4) {
416-
vmovups(xmm_src, ptr[param1 + offset]);
417-
switch (type_) {
418-
case operand_type::relu:
419-
relu_xmm(xmm_dst, xmm_src, xmm_zero);
420-
break;
421-
case operand_type::exp:
422-
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
423-
break;
424-
default:
425-
break;
406+
int block = XMM_FLOAT_BLOCK;
407+
while (rest > 0) {
408+
if (rest >= 4) {
409+
vmovups(xmm_src, ptr[param1 + offset]);
410+
} else if (rest >= 2) {
411+
vmovq(xmm_src, ptr[param1 + offset]);
412+
} else {
413+
vmovss(xmm_src, ptr[param1 + offset]);
426414
}
427-
vmovups(ptr[param2 + offset], xmm_dst);
428-
offset += sizeof(float) * 4;
429-
rest -= 4;
430-
}
431-
if (rest >= 2) {
432-
vmovq(xmm_src, ptr[param1 + offset]);
433415
switch (type_) {
434416
case operand_type::relu:
435417
relu_xmm(xmm_dst, xmm_src, xmm_zero);
@@ -440,25 +422,16 @@ void VActJitCode::generate() {
440422
default:
441423
break;
442424
}
443-
vmovq(ptr[param2 + offset], xmm_dst);
444-
offset += sizeof(float) * 2;
445-
rest -= 2;
446-
}
447-
if (rest > 0) {
448-
// vmovups();
449-
vmovss(xmm_src, ptr[param1 + offset]);
450-
451-
switch (type_) {
452-
case operand_type::relu:
453-
relu_xmm(xmm_dst, xmm_src, xmm_zero);
454-
break;
455-
case operand_type::exp:
456-
exp_xmm(xmm_dst, xmm_src, 2, 3, 4, 5);
457-
break;
458-
default:
459-
break;
425+
if (rest >= 4) {
426+
vmovups(ptr[param2 + offset], xmm_dst);
427+
} else if (rest >= 2) {
428+
vmovq(ptr[param2 + offset], xmm_dst);
429+
} else {
430+
vmovss(ptr[param2 + offset], xmm_dst);
460431
}
461-
vmovss(ptr[param2 + offset], xmm_dst);
432+
offset += sizeof(float) * block;
433+
rest -= block;
434+
block /= 2;
462435
}
463436
ret();
464437
}

0 commit comments

Comments
 (0)