diff --git a/src/ppl/kernel/arm_server/common/math_neon.h b/src/ppl/kernel/arm_server/common/math_neon.h index 337d0430..02fbbaa4 100644 --- a/src/ppl/kernel/arm_server/common/math_neon.h +++ b/src/ppl/kernel/arm_server/common/math_neon.h @@ -41,10 +41,10 @@ inline float32x4_t v_exp_f32(const float32x4_t v_src) #else tmp = vrndmq_f32(fx); #endif - //TODO: compare is right? - uint32x4_t mask = vceqq_f32(tmp, fx); - mask = vandq_u32(mask, vcvtq_u32_f32(one)); - fx = vsubq_f32(tmp, vcvtq_f32_u32(mask)); + + float32x4_t mask = vreinterpretq_f32_u32(vcgtq_f32(tmp, fx)); + mask = vreinterpretq_f32_s32(vandq_s32(vreinterpretq_s32_f32(mask), vreinterpretq_s32_f32(one))); + fx = vsubq_f32(tmp, mask); tmp = vmulq_f32(fx, vdupq_n_f32(0.693359375)); float32x4_t z = vmulq_f32(fx, vdupq_n_f32(-2.12194440e-4)); @@ -61,10 +61,10 @@ inline float32x4_t v_exp_f32(const float32x4_t v_src) y = vfma(y, z, x); y = vaddq_f32(y, one); - int32x4_t imm0 = vcvtq_s32_f32(fx); - imm0 = vaddq_s32(imm0, vdupq_n_s32(0x7f)); - imm0 = vqrshlq_s32(imm0, vdupq_n_s32(23)); - float32x4_t pow2n = vcvtq_f32_s32(imm0); + int64x2_t imm0 = vreinterpretq_s64_s32(vcvtq_s32_f32(fx)); + imm0 = vreinterpretq_s64_s32(vaddq_s32(vreinterpretq_s32_s64(imm0), vdupq_n_s32(0x7f))); + imm0 = vreinterpretq_s64_s32(vshlq_s32(vreinterpretq_s32_s64(imm0), vdupq_n_s32(23))); + float32x4_t pow2n = vreinterpretq_f32_s32(vreinterpretq_s32_s64(imm0)); y = vmulq_f32(y, pow2n); return y;