Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions ynnpack/base/simd/arm_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
#define XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_
#define XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_

#include <arm_neon.h>

Expand All @@ -18,6 +18,7 @@
#include "ynnpack/base/base.h"
#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/base/simd/vec.h"

namespace ynn {
Expand Down Expand Up @@ -467,6 +468,14 @@ YNN_ALWAYS_INLINE s8x16 max(s8x16 a, s8x16 b) {
return s8x16{vmaxq_s8(a.v, b.v)};
}

using f32x4x2 = multi_vec<f32x4, 2>;

// Converts the eight bf16 lanes of `a` to f32. The result is a pair of 4-lane
// vectors: element 0 is produced from the low four input lanes, element 1 from
// the high four. The unnamed `float` parameter only selects the destination
// element type at the call site (e.g. `convert(x, float{})`).
YNN_ALWAYS_INLINE f32x4x2 convert(bf16x8 a, float) {
  // Widen every 16-bit lane to 32 bits with a left shift by 16, so the source
  // bits occupy the upper half of each word and the lower half is zero.
  const uint32x4_t lo_bits = vshll_n_u16(vget_low_u16(a.v), 16);
  const uint32x4_t hi_bits = vshll_n_u16(vget_high_u16(a.v), 16);

  // Reinterpret (no value conversion) the widened words as f32 lanes.
  f32x4x2 out;
  out.v[0].v = vreinterpretq_f32_u32(lo_bits);
  out.v[1].v = vreinterpretq_f32_u32(hi_bits);
  return out;
}

#ifdef YNN_ARCH_ARM32
YNN_ALWAYS_INLINE float vmaxvq_f32(float32x4_t a) {
float32x2_t max_halves = vmax_f32(vget_low_f32(a), vget_high_f32(a));
Expand Down Expand Up @@ -562,4 +571,4 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(

} // namespace ynn

#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_
2 changes: 2 additions & 0 deletions ynnpack/base/simd/test/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ TEST_MAX(arm_neon, s16x8, /*arch_flags=*/0);
TEST_MAX(arm_neon, f32x4, /*arch_flags=*/0);
TEST_MAX(arm_neon, s32x4, /*arch_flags=*/0);

TEST_CONVERT(arm_neon, float, bfloat16, 8, 2, /*arch_flags=*/0);

TEST_HORIZONTAL_MIN(arm_neon, u8x16, /*arch_flags=*/0);
TEST_HORIZONTAL_MIN(arm_neon, s8x16, /*arch_flags=*/0);
TEST_HORIZONTAL_MIN(arm_neon, s16x8, /*arch_flags=*/0);
Expand Down
32 changes: 22 additions & 10 deletions ynnpack/base/simd/test/generic.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <gtest/gtest.h>
#include "ynnpack/base/arch.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/base/simd/vec.h"

namespace ynn {
Expand Down Expand Up @@ -259,7 +260,7 @@ void test_extract(uint32_t arch_flags) {
test_extract<to, from>(arch_flags); \
}

template <typename To, typename From, size_t N>
template <typename To, typename From, size_t N, size_t M>
void test_convert(uint32_t arch_flags) {
if (!is_arch_supported(arch_flags)) {
GTEST_SKIP() << "Unsupported architecture";
Expand All @@ -269,19 +270,30 @@ void test_convert(uint32_t arch_flags) {
for (size_t i = 0; i < N; ++i) {
src[i] = static_cast<From>(i);
}
vec<From, N> from_v = load(src, vec<From, N>{});
vec<To, N> to_v = convert(from_v, To{});

To dst[N];
store(dst, to_v);
for (size_t i = 0; i < N; ++i) {
ASSERT_EQ(dst[i], static_cast<To>(src[i]));
if constexpr (M == 1) {
vec<From, N> from_v = load(src, vec<From, N>{});
vec<To, N> to_v = convert(from_v, To{});
To dst[N];
store(dst, to_v);
for (size_t i = 0; i < N; ++i) {
ASSERT_EQ(dst[i], static_cast<To>(src[i]));
}
} else {
using ToVec = multi_vec<vec<To, N / M>, M>;
vec<From, N> from_v = load(src, vec<From, N>{});
ToVec to_v = convert(from_v, To{});
To dst[ToVec::N];
store(dst, to_v, ToVec::N);
for (size_t i = 0; i < ToVec::N; ++i) {
ASSERT_EQ(dst[i], static_cast<To>(src[i]));
}
}
}

#define TEST_CONVERT(test_class, to, from, N, arch_flags) \
TEST(test_class, convert_##to##_##from) { \
test_convert<to, from, N>(arch_flags); \
#define TEST_CONVERT(test_class, to, from, N, M, arch_flags) \
TEST(test_class, convert_##to##_##from) { \
test_convert<to, from, N, M>(arch_flags); \
}

// This function has a max of n at n, and descends to 0 at either 0 or 2*n - 1.
Expand Down
2 changes: 1 addition & 1 deletion ynnpack/base/simd/test/x86_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ TEST_MAX(x86_avx2, s8x32, arch_flag::avx2);
TEST_MAX(x86_avx2, s16x16, arch_flag::avx2);
TEST_MAX(x86_avx2, s32x8, arch_flag::avx2);

TEST_CONVERT(x86_avx2, float, bfloat16, 8, arch_flag::avx2);
TEST_CONVERT(x86_avx2, float, bfloat16, 8, 1, arch_flag::avx2);

TEST_HORIZONTAL_MIN(x86_avx2, u8x32, arch_flag::avx2);
TEST_HORIZONTAL_MIN(x86_avx2, s8x32, arch_flag::avx2);
Expand Down
4 changes: 2 additions & 2 deletions ynnpack/base/simd/test/x86_avx512f.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ TEST_EXTRACT(x86_avx512f, f16x16, f16x32, arch_flag::avx512f);
TEST_EXTRACT(x86_avx512f, s8x32, s8x64, arch_flag::avx512f);
TEST_EXTRACT(x86_avx512f, u8x32, u8x64, arch_flag::avx512f);

TEST_CONVERT(x86_avx512f, float, bfloat16, 16, arch_flag::avx512f);
TEST_CONVERT(x86_avx512f, float, half, 16, arch_flag::avx512f);
TEST_CONVERT(x86_avx512f, float, bfloat16, 16, 1, arch_flag::avx512f);
TEST_CONVERT(x86_avx512f, float, half, 16, 1, arch_flag::avx512f);

TEST_HORIZONTAL_MIN(x86_avx512f, u8x64, arch_flag::avx512f);
TEST_HORIZONTAL_MIN(x86_avx512f, s8x64, arch_flag::avx512f);
Expand Down
19 changes: 8 additions & 11 deletions ynnpack/kernels/reduce/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ static f32x4x16 reduce_add(
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
YNN_UNROLL
for (int i = 0; i < 8; ++i) {
f32x4 lo(vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16)));
f32x4 hi(vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16)));

a.v[2 * i + 0] += lo;
a.v[2 * i + 1] += hi;
f32x4x2 b_f32 = convert(b.v[i], float{});
a.v[2 * i + 0] += extract<0>(b_f32, f32x4{});
a.v[2 * i + 1] += extract<1>(b_f32, f32x4{});
}

return a;
Expand All @@ -60,12 +58,11 @@ static f32x4x16 reduce_add(
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
YNN_UNROLL
for (int i = 0; i < 8; ++i) {
float32x4_t lo =
vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
float32x4_t hi =
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo, lo);
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
f32x4x2 b_f32 = convert(b.v[i], float{});
f32x4 lo = extract<0>(b_f32, f32x4{});
f32x4 hi = extract<1>(b_f32, f32x4{});
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo.v, lo.v);
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi.v, hi.v);
}

return a;
Expand Down
Loading