@@ -69,7 +69,7 @@ void vrelu_intri8(const int n, const float* x, float* y) {
69
69
70
70
TEST (JitKernel, vrelu) {
71
71
namespace jit = paddle::operators::math::jitkernel;
72
- for (int d : {7 , 8 , 15 , 16 , 30 , 256 , 512 }) {
72
+ for (int d : {3 , 7 , 8 , 15 , 16 , 30 , 256 , 512 }) {
73
73
std::vector<float > x (d);
74
74
std::vector<float > zref (d), ztgt (d);
75
75
RandomVec<float >(d, x.data (), -10 .f , 1 .f );
@@ -159,7 +159,7 @@ void vexp_mkl(const int n, const float* x, float* y) {
159
159
160
160
TEST (JitKernel, vexp) {
161
161
namespace jit = paddle::operators::math::jitkernel;
162
- for (int d : {7 , 8 , 12 , 15 , 16 , 20 , 30 , 128 , 256 }) {
162
+ for (int d : {1 , 3 , 4 , 6 , 7 , 8 , 12 , 15 , 16 , 20 , 30 , 128 , 256 }) {
163
163
std::vector<float > x (d);
164
164
std::vector<float > zref (d), ztgt (d);
165
165
RandomVec<float >(d, x.data (), -2 .f , 2 .f );
@@ -234,7 +234,7 @@ void vsigmoid_better(
234
234
235
235
TEST (JitKernel, vsigmoid) {
236
236
namespace jit = paddle::operators::math::jitkernel;
237
- for (int d : {7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 , 128 , 256 }) {
237
+ for (int d : {1 , 3 , 4 , 6 , 7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 , 128 , 256 }) {
238
238
std::vector<float > x (d);
239
239
std::vector<float > zref (d), ztgt (d);
240
240
RandomVec<float >(d, x.data (), -2 .f , 2 .f );
@@ -298,7 +298,7 @@ void vtanh_better(
298
298
299
299
TEST (JitKernel, vtanh) {
300
300
namespace jit = paddle::operators::math::jitkernel;
301
- for (int d : {7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 , 128 , 256 }) {
301
+ for (int d : {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 , 128 , 256 }) {
302
302
std::vector<float > x (d);
303
303
std::vector<float > zref (d), ztgt (d);
304
304
RandomVec<float >(d, x.data (), -2 .f , 2 .f );
@@ -389,7 +389,7 @@ void lstm_ctht_better(
389
389
390
390
TEST (JitKernel, lstm) {
391
391
namespace jit = paddle::operators::math::jitkernel;
392
- for (int d : {7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 }) {
392
+ for (int d : {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 15 , 16 , 30 , 32 , 64 , 100 }) {
393
393
int d4 = d * 4 ;
394
394
int d3 = d * 3 ;
395
395
std::vector<float > x (d4), xref (d4);
0 commit comments