1414
1515#include < gtest/gtest.h>
1616#include " ynnpack/base/arch.h"
17+ #include " ynnpack/base/simd/multi_vec.h"
1718#include " ynnpack/base/simd/vec.h"
1819
1920namespace ynn {
@@ -259,7 +260,7 @@ void test_extract(uint32_t arch_flags) {
259260 test_extract<to, from>(arch_flags); \
260261 }
261262
262- template <typename To, typename From, size_t N>
263+ template <typename To, typename From, size_t N, size_t M >
263264void test_convert (uint32_t arch_flags) {
264265 if (!is_arch_supported (arch_flags)) {
265266 GTEST_SKIP () << " Unsupported architecture" ;
@@ -269,19 +270,30 @@ void test_convert(uint32_t arch_flags) {
269270 for (size_t i = 0 ; i < N; ++i) {
270271 src[i] = static_cast <From>(i);
271272 }
272- vec<From, N> from_v = load (src, vec<From, N>{});
273- vec<To, N> to_v = convert (from_v, To{});
274273
275- To dst[N];
276- store (dst, to_v);
277- for (size_t i = 0 ; i < N; ++i) {
278- ASSERT_EQ (dst[i], static_cast <To>(src[i]));
274+ if constexpr (M == 1 ) {
275+ vec<From, N> from_v = load (src, vec<From, N>{});
276+ vec<To, N> to_v = convert (from_v, To{});
277+ To dst[N];
278+ store (dst, to_v);
279+ for (size_t i = 0 ; i < N; ++i) {
280+ ASSERT_EQ (dst[i], static_cast <To>(src[i]));
281+ }
282+ } else {
283+ using ToVec = multi_vec<vec<To, N / M>, M>;
284+ vec<From, N> from_v = load (src, vec<From, N>{});
285+ ToVec to_v = convert (from_v, To{});
286+ To dst[ToVec::N];
287+ store (dst, to_v, ToVec::N);
288+ for (size_t i = 0 ; i < ToVec::N; ++i) {
289+ ASSERT_EQ (dst[i], static_cast <To>(src[i]));
290+ }
279291 }
280292}
281293
282- #define TEST_CONVERT (test_class, to, from, N, arch_flags ) \
283- TEST (test_class, convert_##to##_##from) { \
284- test_convert<to, from, N>(arch_flags); \
294+ #define TEST_CONVERT (test_class, to, from, N, M, arch_flags ) \
295+ TEST (test_class, convert_##to##_##from) { \
296+ test_convert<to, from, N, M >(arch_flags); \
285297 }
286298
287299// This function has a max of n at n, and descends to 0 at either 0 or 2*n - 1.
0 commit comments