Skip to content

Commit 6787999

Browse files
Aelphyxnnpack-bot
authored andcommitted
Added bf16x8->f32x4x2 convert for arm neon.
PiperOrigin-RevId: 844040217
1 parent 7746844 commit 6787999

File tree

6 files changed

+47
-27
lines changed

6 files changed

+47
-27
lines changed

ynnpack/base/simd/arm_neon.h

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33
// This source code is licensed under the BSD-style license found in the
44
// LICENSE file in the root directory of this source tree.
55

6-
#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
7-
#define XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
6+
#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_
7+
#define XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_
88

99
#include <arm_neon.h>
1010

@@ -18,6 +18,7 @@
1818
#include "ynnpack/base/base.h"
1919
#include "ynnpack/base/bfloat16.h"
2020
#include "ynnpack/base/half.h"
21+
#include "ynnpack/base/simd/multi_vec.h"
2122
#include "ynnpack/base/simd/vec.h"
2223

2324
namespace ynn {
@@ -467,6 +468,14 @@ YNN_ALWAYS_INLINE s8x16 max(s8x16 a, s8x16 b) {
467468
return s8x16{vmaxq_s8(a.v, b.v)};
468469
}
469470

471+
using f32x4x2 = multi_vec<f32x4, 2>;
472+
YNN_ALWAYS_INLINE f32x4x2 convert(bf16x8 a, float) {
473+
f32x4x2 result;
474+
result.v[0].v = vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(a.v), 16));
475+
result.v[1].v = vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(a.v), 16));
476+
return result;
477+
}
478+
470479
#ifdef YNN_ARCH_ARM32
471480
YNN_ALWAYS_INLINE float vmaxvq_f32(float32x4_t a) {
472481
float32x2_t max_halves = vmax_f32(vget_low_f32(a), vget_high_f32(a));
@@ -562,4 +571,4 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(
562571

563572
} // namespace ynn
564573

565-
#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
574+
#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_

ynnpack/base/simd/test/arm_neon.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ TEST_MAX(arm_neon, s16x8, /*arch_flags=*/0);
6969
TEST_MAX(arm_neon, f32x4, /*arch_flags=*/0);
7070
TEST_MAX(arm_neon, s32x4, /*arch_flags=*/0);
7171

72+
TEST_CONVERT(arm_neon, float, bfloat16, 8, 2, /*arch_flags=*/0);
73+
7274
TEST_HORIZONTAL_MIN(arm_neon, u8x16, /*arch_flags=*/0);
7375
TEST_HORIZONTAL_MIN(arm_neon, s8x16, /*arch_flags=*/0);
7476
TEST_HORIZONTAL_MIN(arm_neon, s16x8, /*arch_flags=*/0);

ynnpack/base/simd/test/generic.h

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
#include <gtest/gtest.h>
1616
#include "ynnpack/base/arch.h"
17+
#include "ynnpack/base/simd/multi_vec.h"
1718
#include "ynnpack/base/simd/vec.h"
1819

1920
namespace ynn {
@@ -259,7 +260,7 @@ void test_extract(uint32_t arch_flags) {
259260
test_extract<to, from>(arch_flags); \
260261
}
261262

262-
template <typename To, typename From, size_t N>
263+
template <typename To, typename From, size_t N, size_t M>
263264
void test_convert(uint32_t arch_flags) {
264265
if (!is_arch_supported(arch_flags)) {
265266
GTEST_SKIP() << "Unsupported architecture";
@@ -269,19 +270,30 @@ void test_convert(uint32_t arch_flags) {
269270
for (size_t i = 0; i < N; ++i) {
270271
src[i] = static_cast<From>(i);
271272
}
272-
vec<From, N> from_v = load(src, vec<From, N>{});
273-
vec<To, N> to_v = convert(from_v, To{});
274273

275-
To dst[N];
276-
store(dst, to_v);
277-
for (size_t i = 0; i < N; ++i) {
278-
ASSERT_EQ(dst[i], static_cast<To>(src[i]));
274+
if constexpr (M == 1) {
275+
vec<From, N> from_v = load(src, vec<From, N>{});
276+
vec<To, N> to_v = convert(from_v, To{});
277+
To dst[N];
278+
store(dst, to_v);
279+
for (size_t i = 0; i < N; ++i) {
280+
ASSERT_EQ(dst[i], static_cast<To>(src[i]));
281+
}
282+
} else {
283+
using ToVec = multi_vec<vec<To, N / M>, M>;
284+
vec<From, N> from_v = load(src, vec<From, N>{});
285+
ToVec to_v = convert(from_v, To{});
286+
To dst[ToVec::N];
287+
store(dst, to_v, ToVec::N);
288+
for (size_t i = 0; i < ToVec::N; ++i) {
289+
ASSERT_EQ(dst[i], static_cast<To>(src[i]));
290+
}
279291
}
280292
}
281293

282-
#define TEST_CONVERT(test_class, to, from, N, arch_flags) \
283-
TEST(test_class, convert_##to##_##from) { \
284-
test_convert<to, from, N>(arch_flags); \
294+
#define TEST_CONVERT(test_class, to, from, N, M, arch_flags) \
295+
TEST(test_class, convert_##to##_##from) { \
296+
test_convert<to, from, N, M>(arch_flags); \
285297
}
286298

287299
// This function has a max of n at n, and descends to 0 at either 0 or 2*n - 1.

ynnpack/base/simd/test/x86_avx2.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ TEST_MAX(x86_avx2, s8x32, arch_flag::avx2);
3232
TEST_MAX(x86_avx2, s16x16, arch_flag::avx2);
3333
TEST_MAX(x86_avx2, s32x8, arch_flag::avx2);
3434

35-
TEST_CONVERT(x86_avx2, float, bfloat16, 8, arch_flag::avx2);
35+
TEST_CONVERT(x86_avx2, float, bfloat16, 8, 1, arch_flag::avx2);
3636

3737
TEST_HORIZONTAL_MIN(x86_avx2, u8x32, arch_flag::avx2);
3838
TEST_HORIZONTAL_MIN(x86_avx2, s8x32, arch_flag::avx2);

ynnpack/base/simd/test/x86_avx512f.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,8 @@ TEST_EXTRACT(x86_avx512f, f16x16, f16x32, arch_flag::avx512f);
6666
TEST_EXTRACT(x86_avx512f, s8x32, s8x64, arch_flag::avx512f);
6767
TEST_EXTRACT(x86_avx512f, u8x32, u8x64, arch_flag::avx512f);
6868

69-
TEST_CONVERT(x86_avx512f, float, bfloat16, 16, arch_flag::avx512f);
70-
TEST_CONVERT(x86_avx512f, float, half, 16, arch_flag::avx512f);
69+
TEST_CONVERT(x86_avx512f, float, bfloat16, 16, 1, arch_flag::avx512f);
70+
TEST_CONVERT(x86_avx512f, float, half, 16, 1, arch_flag::avx512f);
7171

7272
TEST_HORIZONTAL_MIN(x86_avx512f, u8x64, arch_flag::avx512f);
7373
TEST_HORIZONTAL_MIN(x86_avx512f, s8x64, arch_flag::avx512f);

ynnpack/kernels/reduce/arm_neon.cc

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@ static f32x4x16 reduce_add(
3333
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
3434
YNN_UNROLL
3535
for (int i = 0; i < 8; ++i) {
36-
f32x4 lo(vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16)));
37-
f32x4 hi(vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16)));
38-
39-
a.v[2 * i + 0] += lo;
40-
a.v[2 * i + 1] += hi;
36+
f32x4x2 b_f32 = convert(b.v[i], float{});
37+
a.v[2 * i + 0] += extract<0>(b_f32.v, f32x4{});
38+
a.v[2 * i + 1] += extract<1>(b_f32.v, f32x4{});
4139
}
4240

4341
return a;
@@ -60,12 +58,11 @@ static f32x4x16 reduce_add(
6058
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
6159
YNN_UNROLL
6260
for (int i = 0; i < 8; ++i) {
63-
float32x4_t lo =
64-
vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
65-
float32x4_t hi =
66-
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
67-
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo, lo);
68-
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
61+
f32x4x2 b_f32 = convert(b.v[i], float{});
62+
f32x4 lo = extract<0>(b_f32.v, f32x4{});
63+
f32x4 hi = extract<1>(b_f32.v, f32x4{});
64+
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo.v, lo.v);
65+
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi.v, hi.v);
6966
}
7067

7168
return a;

0 commit comments

Comments
 (0)