diff --git a/ynnpack/base/simd/arm_neon.h b/ynnpack/base/simd/arm_neon.h index 405365152c0..e9a3f598d5e 100644 --- a/ynnpack/base/simd/arm_neon.h +++ b/ynnpack/base/simd/arm_neon.h @@ -3,8 +3,8 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. -#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_H_ -#define XNNPACK_YNNPACK_BASE_SIMD_ARM_H_ +#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_ +#define XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_ #include @@ -18,6 +18,7 @@ #include "ynnpack/base/base.h" #include "ynnpack/base/bfloat16.h" #include "ynnpack/base/half.h" +#include "ynnpack/base/simd/multi_vec.h" #include "ynnpack/base/simd/vec.h" namespace ynn { @@ -467,6 +468,14 @@ YNN_ALWAYS_INLINE s8x16 max(s8x16 a, s8x16 b) { return s8x16{vmaxq_s8(a.v, b.v)}; } +using f32x4x2 = multi_vec; +YNN_ALWAYS_INLINE f32x4x2 convert(bf16x8 a, float) { + f32x4x2 result; + result.v[0].v = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a.v), 16)); + result.v[1].v = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a.v), 16)); + return result; +} + #ifdef YNN_ARCH_ARM32 YNN_ALWAYS_INLINE float vmaxvq_f32(float32x4_t a) { float32x2_t max_halves = vmax_f32(vget_low_f32(a), vget_high_f32(a)); @@ -562,4 +571,4 @@ YNN_ALWAYS_INLINE std::array, 4> transpose( } // namespace ynn -#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_H_ +#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_ diff --git a/ynnpack/base/simd/test/arm_neon.cc b/ynnpack/base/simd/test/arm_neon.cc index 68e2743f3e7..2d8c60c793d 100644 --- a/ynnpack/base/simd/test/arm_neon.cc +++ b/ynnpack/base/simd/test/arm_neon.cc @@ -69,6 +69,8 @@ TEST_MAX(arm_neon, s16x8, /*arch_flags=*/0); TEST_MAX(arm_neon, f32x4, /*arch_flags=*/0); TEST_MAX(arm_neon, s32x4, /*arch_flags=*/0); +TEST_CONVERT(arm_neon, float, bfloat16, 8, 2, /*arch_flags=*/0); + TEST_HORIZONTAL_MIN(arm_neon, u8x16, /*arch_flags=*/0); TEST_HORIZONTAL_MIN(arm_neon, s8x16, /*arch_flags=*/0); TEST_HORIZONTAL_MIN(arm_neon, s16x8, /*arch_flags=*/0); diff --git a/ynnpack/base/simd/test/generic.h b/ynnpack/base/simd/test/generic.h index 551995db054..dce78f9f9b8 100644 --- a/ynnpack/base/simd/test/generic.h +++ b/ynnpack/base/simd/test/generic.h @@ -14,6 +14,7 @@ #include #include "ynnpack/base/arch.h" +#include "ynnpack/base/simd/multi_vec.h" #include "ynnpack/base/simd/vec.h" namespace ynn { @@ -259,7 +260,7 @@ void test_extract(uint32_t arch_flags) { test_extract(arch_flags); \ } -template +template void test_convert(uint32_t arch_flags) { if (!is_arch_supported(arch_flags)) { GTEST_SKIP() << "Unsupported architecture"; @@ -269,19 +270,30 @@ void test_convert(uint32_t arch_flags) { for (size_t i = 0; i < N; ++i) { src[i] = static_cast(i); } - vec from_v = load(src, vec{}); - vec to_v = convert(from_v, To{}); - To dst[N]; - store(dst, to_v); - for (size_t i = 0; i < N; ++i) { - ASSERT_EQ(dst[i], static_cast(src[i])); + if constexpr (M == 1) { + vec from_v = load(src, vec{}); + vec to_v = convert(from_v, To{}); + To dst[N]; + store(dst, to_v); + for (size_t i = 0; i < N; ++i) { + ASSERT_EQ(dst[i], static_cast(src[i])); + } + } else { + using ToVec = multi_vec, M>; + vec from_v = load(src, vec{}); + ToVec to_v = convert(from_v, To{}); + To dst[ToVec::N]; + store(dst, to_v, ToVec::N); + for (size_t i = 0; i < ToVec::N; ++i) { + ASSERT_EQ(dst[i], static_cast(src[i])); + } } } -#define TEST_CONVERT(test_class, to, from, N, arch_flags) \ - TEST(test_class, convert_##to##_##from) { \ - test_convert(arch_flags); \ +#define TEST_CONVERT(test_class, to, from, N, M, arch_flags) \ + TEST(test_class, convert_##to##_##from) { \ + test_convert(arch_flags); \ } // This function has a max of n at n, and descends to 0 at either 0 or 2*n - 1. diff --git a/ynnpack/base/simd/test/x86_avx2.cc b/ynnpack/base/simd/test/x86_avx2.cc index 0dc10bb827c..45687ad2e24 100644 --- a/ynnpack/base/simd/test/x86_avx2.cc +++ b/ynnpack/base/simd/test/x86_avx2.cc @@ -32,7 +32,7 @@ TEST_MAX(x86_avx2, s8x32, arch_flag::avx2); TEST_MAX(x86_avx2, s16x16, arch_flag::avx2); TEST_MAX(x86_avx2, s32x8, arch_flag::avx2); -TEST_CONVERT(x86_avx2, float, bfloat16, 8, arch_flag::avx2); +TEST_CONVERT(x86_avx2, float, bfloat16, 8, 1, arch_flag::avx2); TEST_HORIZONTAL_MIN(x86_avx2, u8x32, arch_flag::avx2); TEST_HORIZONTAL_MIN(x86_avx2, s8x32, arch_flag::avx2); diff --git a/ynnpack/base/simd/test/x86_avx512f.cc b/ynnpack/base/simd/test/x86_avx512f.cc index 900bd394ecb..0a6fef84dcc 100644 --- a/ynnpack/base/simd/test/x86_avx512f.cc +++ b/ynnpack/base/simd/test/x86_avx512f.cc @@ -66,8 +66,8 @@ TEST_EXTRACT(x86_avx512f, f16x16, f16x32, arch_flag::avx512f); TEST_EXTRACT(x86_avx512f, s8x32, s8x64, arch_flag::avx512f); TEST_EXTRACT(x86_avx512f, u8x32, u8x64, arch_flag::avx512f); -TEST_CONVERT(x86_avx512f, float, bfloat16, 16, arch_flag::avx512f); -TEST_CONVERT(x86_avx512f, float, half, 16, arch_flag::avx512f); +TEST_CONVERT(x86_avx512f, float, bfloat16, 16, 1, arch_flag::avx512f); +TEST_CONVERT(x86_avx512f, float, half, 16, 1, arch_flag::avx512f); TEST_HORIZONTAL_MIN(x86_avx512f, u8x64, arch_flag::avx512f); TEST_HORIZONTAL_MIN(x86_avx512f, s8x64, arch_flag::avx512f); diff --git a/ynnpack/kernels/reduce/arm_neon.cc b/ynnpack/kernels/reduce/arm_neon.cc index be0de465004..d716072217e 100644 --- a/ynnpack/kernels/reduce/arm_neon.cc +++ b/ynnpack/kernels/reduce/arm_neon.cc @@ -33,11 +33,9 @@ static f32x4x16 reduce_add( std::integral_constant /*horizontal_factor*/) { YNN_UNROLL for (int i = 0; i < 8; ++i) { - f32x4 lo(vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16))); - f32x4 hi(vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16))); - - a.v[2 * i + 0] += lo; - a.v[2 * i + 1] += hi; + f32x4x2 b_f32 = convert(b.v[i], float{}); + a.v[2 * i + 0] += extract<0>(b_f32, f32x4{}); + a.v[2 * i + 1] += extract<1>(b_f32, f32x4{}); } return a; @@ -60,12 +58,11 @@ static f32x4x16 reduce_add( std::integral_constant /*horizontal_factor*/) { YNN_UNROLL for (int i = 0; i < 8; ++i) { - float32x4_t lo = - vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16)); - float32x4_t hi = - vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16)); - a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo, lo); - a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi); + f32x4x2 b_f32 = convert(b.v[i], float{}); + f32x4 lo = extract<0>(b_f32, f32x4{}); + f32x4 hi = extract<1>(b_f32, f32x4{}); + a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo.v, lo.v); + a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi.v, hi.v); } return a;