Skip to content

Commit d035310

Browse files
author
Fikret Ardal
committed
SIMD gather functions with a constant stride, for every supported element type, together with their tests
1 parent 57c3b3f commit d035310

File tree

6 files changed

+255
-62
lines changed

6 files changed

+255
-62
lines changed

c++/nda/simd/arch/AVX/functions.hpp

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,48 @@ namespace nda::simd {
547547
return -(x * y + z);
548548
}
549549

550+
//Gather functions.
551+
template <>
552+
inline simd_i8 gather(const simd_i8::value_t *from, const long stride) {
553+
simd_i8 simd_stride(static_cast<int32_t>(stride));
554+
const simd_i8 multiplier({0, 1, 2, 3, 4, 5, 6, 7});
555+
simd_i8 vindex = simd_stride * multiplier;
556+
return simd_i8(_mm256_i32gather_epi32(from, vindex, sizeof(simd_i8::value_t)));
557+
}
558+
559+
template <>
560+
inline simd_l4 gather(const simd_l4::value_t *from, const long stride) {
561+
simd_l4 simd_stride(stride);
562+
const simd_l4 multiplier({0, 1, 2, 3});
563+
simd_l4 vindex = simd_stride * multiplier;
564+
return simd_l4(_mm256_i64gather_epi64(reinterpret_cast<const long long int*>(from), vindex, sizeof(simd_l4::value_t)));
565+
}
566+
567+
template <>
568+
inline simd_f8 gather(const simd_f8::value_t *from, const long stride) {
569+
simd_i8 simd_stride(static_cast<int32_t>(stride));
570+
const simd_i8 multiplier({0, 1, 2, 3, 4, 5, 6, 7});
571+
simd_i8 vindex = simd_stride * multiplier;
572+
return simd_f8(_mm256_i32gather_ps(from, vindex, sizeof(simd_f8::value_t)));
573+
}
574+
575+
template <>
576+
inline simd_d4 gather(const simd_d4::value_t *from, const long stride) {
577+
simd_l4 simd_stride(stride);
578+
const simd_l4 multiplier({0, 1, 2, 3});
579+
simd_l4 vindex = simd_stride * multiplier;
580+
return simd_d4(_mm256_i64gather_pd(from, vindex, sizeof(simd_d4::value_t)));
581+
}
582+
583+
template <>
584+
inline simd_cf4 gather(const simd_cf4::value_t *from, const long stride) {
585+
return simd_cf4(_mm256_castpd_ps(gather<simd_d4>(reinterpret_cast<const simd_d4::value_t *>(from), stride)));
586+
}
587+
588+
template <>
589+
inline simd_cd2 gather(const simd_cd2::value_t *from, const long stride) {
590+
return simd_cd2(_mm256_set_pd(from[stride].imag(), from[stride].real(), from[0].imag(), from[0].real()));
591+
}
550592

551593
} // namespace nda::simd
552594
#endif

c++/nda/simd/arch/AVX512/functions.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,56 @@ namespace nda::simd {
446446
inline simd_l8 fma_nsub(const simd_l8 &x, const simd_l8 &y, const simd_l8 &z) {
447447
return -(x * y + z);
448448
}
449+
//Gather functions.
450+
template <>
451+
inline simd_i16 gather(const simd_i16::value_t *from, const long stride) {
452+
simd_i16 simd_stride(static_cast<int32_t>(stride));
453+
const simd_i16 multiplier({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
454+
simd_i16 vindex = simd_stride * multiplier;
455+
return simd_l8(_mm512_i64gather_epi64(vindex, from, sizeof(simd_l8::value_t)));
456+
}
457+
458+
template <>
459+
inline simd_l8 gather(const simd_l8::value_t *from, const long stride) {
460+
simd_l8 simd_stride(stride);
461+
const simd_l8 multiplier({0, 1, 2, 3, 4, 5, 6, 7});
462+
simd_l8 vindex = simd_stride * multiplier;
463+
return simd_l8(_mm512_i64gather_epi64(vindex, from, sizeof(simd_l8::value_t)));
464+
}
465+
466+
template <>
467+
inline simd_f16 gather(const simd_f16::value_t *from, const long stride) {
468+
simd_i16 simd_stride(static_cast<int32_t>(stride));
469+
const simd_i16 multiplier({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
470+
simd_i16 vindex = simd_stride * multiplier;
471+
return simd_f16(_mm512_i32gather_ps(vindex, from, sizeof(simd_f16::value_t)));
472+
}
473+
474+
template <>
475+
inline simd_d8 gather(const simd_d8::value_t *from, const long stride) {
476+
simd_l8 simd_stride(stride);
477+
const simd_l8 multiplier({0, 1, 2, 3, 4, 5, 6, 7});
478+
simd_l8 vindex = simd_stride * multiplier;
479+
return simd_d8(_mm512_i64gather_pd(vindex, from, sizeof(simd_d8::value_t)));
480+
}
481+
482+
template <>
483+
inline simd_cf8 gather(const simd_cf8::value_t *from, const long stride) {
484+
return simd_cf8(_mm512_castpd_ps(gather<simd_d8>(reinterpret_cast<const simd_d8::value_t*>(from), stride)));
485+
}
486+
487+
template <>
488+
inline simd_cd4 gather(const simd_cd4::value_t *from, const long stride) {
489+
simd_cd1 a,b,c,d;
490+
a.load_unaligned(from);
491+
b.load_unaligned(from + stride);
492+
c.load_unaligned(from + 2 * stride);
493+
d.load_unaligned(from + 3 * stride);
494+
__m256d ab = _mm256_insertf128_pd(_mm256_castpd128_pd256(a), b, 1);
495+
__m256d cd = _mm256_insertf128_pd(_mm256_castpd128_pd256(c), d, 1);
496+
return simd_cd4(_mm512_insertf64x4(_mm512_castpd256_pd512(ab), cd , 1))
497+
498+
}
449499

450500
} // namespace nda::simd
451501
#endif

c++/nda/simd/arch/Default/functions.hpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,4 +344,36 @@ namespace nda::simd {
344344
inline simd_l1 fma_nsub(const simd_l1 &x, const simd_l1 &y, const simd_l1 &z) {
345345
return -(x * y + z);
346346
}
347+
348+
//Gather functions
349+
template <>
350+
inline simd_i1 gather(const simd_i1::value_t *from, [[maybe_unused]]const long stride) {
351+
return simd_i1(from);
352+
}
353+
354+
template <>
355+
inline simd_l1 gather(const simd_l1::value_t *from, [[maybe_unused]]const long stride) {
356+
return simd_l1(from);
357+
}
358+
359+
template <>
360+
inline simd_f1 gather(const simd_f1::value_t *from,[[maybe_unused]] const long stride) {
361+
return simd_f1(from);
362+
}
363+
364+
template <>
365+
inline simd_d1 gather(const simd_d1::value_t *from, [[maybe_unused]] const long stride) {
366+
return simd_d1(from);
367+
}
368+
369+
template <>
370+
inline simd_cf1 gather(const simd_cf1::value_t *from, [[maybe_unused]] const long stride) {
371+
return simd_cf1(from);
372+
}
373+
374+
template <>
375+
inline simd_cd1_d gather(const simd_cd1_d::value_t *from,[[maybe_unused]] const long stride) {
376+
return simd_cd1_d(from);
377+
}
378+
347379
} // namespace nda::simd

c++/nda/simd/arch/SSE/functions.hpp

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -367,21 +367,21 @@ namespace nda::simd {
367367

368368
template <>
369369
inline simd_cf2 fma_nadd(const simd_cf2 &x, const simd_cf2 &y, const simd_cf2 &z) {
370-
__m128 x_odd = _mm_movehdup_ps(x);
371-
__m128 x_even = _mm_moveldup_ps(x);
372-
__m128 y_swap = _mm_permute_ps(y, NDA_SHUFFLE_MASK4(1,0,3,2));
370+
__m128 x_odd = _mm_movehdup_ps(x);
371+
__m128 x_even = _mm_moveldup_ps(x);
372+
__m128 y_swap = _mm_permute_ps(y, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
373373
simd_cf2 y_swap_conj = conj(simd_cf2(y_swap)); // TODO: Eigen bug create issue maybe in eigen.
374-
__m128 result = _mm_fmsub_ps(x_odd, y_swap_conj, _mm_fmsub_ps(x_even, y, z));
374+
__m128 result = _mm_fmsub_ps(x_odd, y_swap_conj, _mm_fmsub_ps(x_even, y, z));
375375
return simd_cf2(result);
376376
}
377377

378378
template <>
379379
inline simd_cd1 fma_nadd(const simd_cd1 &x, const simd_cd1 &y, const simd_cd1 &z) {
380-
__m128d x_odd = _mm_permute_pd(x, 0x3);
381-
__m128d x_even = _mm_movedup_pd(x);
382-
__m128d y_swap = _mm_permute_pd(y, 0x1);
380+
__m128d x_odd = _mm_permute_pd(x, 0x3);
381+
__m128d x_even = _mm_movedup_pd(x);
382+
__m128d y_swap = _mm_permute_pd(y, 0x1);
383383
simd_cd1 y_swap_conj = conj(simd_cd1(y_swap));
384-
__m128d result = _mm_fmsub_pd(x_odd, y_swap_conj, _mm_fmsub_pd(x_even, y, z));
384+
__m128d result = _mm_fmsub_pd(x_odd, y_swap_conj, _mm_fmsub_pd(x_even, y, z));
385385
return simd_cd1(result);
386386
}
387387

@@ -398,19 +398,19 @@ namespace nda::simd {
398398

399399
template <>
400400
inline simd_cf2 fma_nsub(const simd_cf2 &x, const simd_cf2 &y, const simd_cf2 &z) {
401-
__m128 x_odd = _mm_movehdup_ps(x);
402-
__m128 x_even = _mm_moveldup_ps(x);
403-
__m128 y_swap = _mm_permute_ps(y, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
401+
__m128 x_odd = _mm_movehdup_ps(x);
402+
__m128 x_even = _mm_moveldup_ps(x);
403+
__m128 y_swap = _mm_permute_ps(y, NDA_SHUFFLE_MASK4(1, 0, 3, 2));
404404
simd_cf2 y_swap_conj = conj(simd_cf2(y_swap));
405-
__m128 result = _mm_fmsub_ps(x_odd, y_swap_conj, _mm_fmadd_ps(x_even, y, z));
405+
__m128 result = _mm_fmsub_ps(x_odd, y_swap_conj, _mm_fmadd_ps(x_even, y, z));
406406
return simd_cf2(result);
407407
}
408408

409409
template <>
410410
inline simd_cd1 fma_nsub(const simd_cd1 &x, const simd_cd1 &y, const simd_cd1 &z) {
411-
__m128d x_odd = _mm_permute_pd(x, 0x3);
412-
__m128d x_even = _mm_movedup_pd(x);
413-
__m128d y_swap = _mm_permute_pd(y, 0x1);
411+
__m128d x_odd = _mm_permute_pd(x, 0x3);
412+
__m128d x_even = _mm_movedup_pd(x);
413+
__m128d y_swap = _mm_permute_pd(y, 0x1);
414414
simd_cd1 y_swap_conj = conj(simd_cd1(y_swap));
415415

416416
__m128d result = _mm_fmsub_pd(x_odd, y_swap_conj, _mm_fmadd_pd(x_even, y, z));
@@ -546,5 +546,52 @@ namespace nda::simd {
546546
inline simd_l2 fma_nsub(const simd_l2 &x, const simd_l2 &y, const simd_l2 &z) {
547547
return -(x * y + z);
548548
}
549+
// Gather functions: the per-lane offsets (multiples of the stride) are carried in vindex.
550+
#ifdef __AVX2__
551+
template <>
552+
inline simd_i4 gather(const simd_i4::value_t *from, const long stride) {
553+
simd_i4 simd_stride(static_cast<int32_t>(stride));
554+
const simd_i4 multiplier({0, 1, 2, 3});
555+
simd_i4 vindex = simd_stride * multiplier;
556+
return simd_i4(_mm_i32gather_epi32(from, vindex, sizeof(simd_i4::value_t)));
557+
}
558+
559+
template <>
560+
inline simd_l2 gather(const simd_l2::value_t * from, const long stride) {
561+
simd_l2 simd_stride(stride);
562+
const simd_l2 multiplier({0, 1});
563+
simd_l2 vindex = simd_stride * multiplier;
564+
return simd_l2(_mm_i64gather_epi64(reinterpret_cast<const long long int*>(from), vindex, sizeof(simd_l2::value_t)));
565+
}
566+
567+
template <>
568+
inline simd_f4 gather(const simd_f4::value_t *from, const long stride) {
569+
simd_i4 simd_stride(static_cast<int32_t>(stride));
570+
const simd_i4 multiplier({0, 1, 2, 3});
571+
simd_i4 vindex = simd_stride * multiplier;
572+
return simd_f4(_mm_i32gather_ps(from, vindex, sizeof(simd_f4::value_t)));
573+
}
574+
575+
template <>
576+
inline simd_d2 gather(const simd_d2::value_t *from, const long stride) {
577+
simd_l2 simd_stride(stride);
578+
const simd_l2 multiplier({0, 1});
579+
simd_l2 vindex = simd_stride * multiplier;
580+
return simd_d2(_mm_i64gather_pd(from, vindex, sizeof(simd_d2::value_t)));
581+
}
582+
583+
template <>
584+
inline simd_cf2 gather(const simd_cf2::value_t *from, const long stride) {
585+
return simd_cf2(_mm_castpd_ps(gather<simd_d2>(reinterpret_cast<const simd_d2::value_t *>(from), stride)));
586+
}
587+
588+
template <>
589+
inline simd_cd1 gather(const simd_cd1::value_t *from, [[maybe_unused]] const long stride) {
590+
simd_cd1 tmp;
591+
tmp.load_unaligned(from);
592+
return tmp;
593+
}
594+
595+
#endif
549596
} // namespace nda::simd
550597
#endif

c++/nda/simd/arch/functions_forward.hpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#pragma once
22

3+
#include <array>
4+
35
namespace nda::simd {
46
template <typename T>
57
T abs(const T &);
@@ -28,19 +30,19 @@ namespace nda::simd {
2830
template <typename T>
2931
typename T::value_t reduce_mul(const T &);
3032

31-
//TODO: Implement these
32-
template<typename T>
33-
T fma_add(const T&, const T&, const T&);
34-
35-
template<typename T>
36-
T fma_sub(const T&, const T&, const T&);
33+
template <typename T>
34+
T fma_add(const T &, const T &, const T &);
3735

38-
template<typename T>
39-
T fma_nadd(const T&, const T&, const T&);
36+
template <typename T>
37+
T fma_sub(const T &, const T &, const T &);
4038

41-
template<typename T>
42-
T fma_nsub(const T&, const T&, const T&);
39+
template <typename T>
40+
T fma_nadd(const T &, const T &, const T &);
4341

42+
template <typename T>
43+
T fma_nsub(const T &, const T &, const T &);
4444

45+
template <typename T>
46+
T gather(const typename T::value_t *, const long);
4547

4648
} // namespace nda::simd

0 commit comments

Comments
 (0)