@@ -24,6 +24,8 @@ namespace { // anonymous namespace for internal structs equivalent to declaring
             // static
 struct zip_low;
 struct zip_hi;
+struct select_even;
+struct select_odd;
 // forward declaration to clean up the code and be able to use this everywhere in the file
 template<class T, uint8_t N, uint8_t K = N> static constexpr auto BestSIMDHelper();
 template<class T, uint8_t N> constexpr auto GetPaddedSIMDWidth();
@@ -43,6 +45,12 @@ constexpr auto zip_low_index =
 template<class arch_t>
 constexpr auto zip_hi_index =
     xsimd::make_batch_constant<xsimd::as_unsigned_integer_t<FLT>, arch_t, zip_hi>();
+template<class arch_t>
+constexpr auto select_even_mask =
+    xsimd::make_batch_bool_constant<FLT, arch_t, select_even>();
+template<class arch_t>
+constexpr auto select_odd_mask =
+    xsimd::make_batch_bool_constant<FLT, arch_t, select_odd>();
 template<typename T, std::size_t N, std::size_t M, std::size_t PaddedM>
 constexpr std::array<std::array<T, PaddedM>, N> pad_2D_array_with_zeros(
     const std::array<std::array<T, M>, N> &input) noexcept;
@@ -67,14 +75,14 @@ template<uint8_t w, class simd_type = xsimd::make_sized_batch_t<
                         FLT, find_optimal_simd_width<FLT, w>()>> // aka ns
 static FINUFFT_ALWAYS_INLINE void eval_kernel_vec_Horner(
     FLT *FINUFFT_RESTRICT ker, FLT x, const finufft_spread_opts &opts) noexcept;
-template<uint8_t ns>
-static void interp_line(FLT *FINUFFT_RESTRICT out, const FLT *du, const FLT *ker,
+template<uint8_t ns, class simd_type = PaddedSIMD<FLT, 2 * ns>>
+static void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
                         BIGINT i1, BIGINT N1);
 template<uint8_t ns>
-static void interp_square(FLT *FINUFFT_RESTRICT out, const FLT *du, const FLT *ker1,
+static void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
                           const FLT *ker2, BIGINT i1, BIGINT i2, BIGINT N1, BIGINT N2);
 template<uint8_t ns>
-static void interp_cube(FLT *FINUFFT_RESTRICT out, const FLT *du, const FLT *ker1,
+static void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
                         const FLT *ker2, const FLT *ker3, BIGINT i1, BIGINT i2, BIGINT i3,
                         BIGINT N1, BIGINT N2, BIGINT N3);
 static void spread_subproblem_1d(BIGINT off1, BIGINT size1, FLT *du0, BIGINT M0, FLT *kx0,
@@ -454,10 +462,10 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
 // Interpolate to NU pts in sorted order from a uniform grid.
 // See spreadinterp() for doc.
 {
-  using simd_type = xsimd::batch<FLT>;
+  using simd_type = PaddedSIMD<FLT, 2 * ns>;
   using arch_t = typename simd_type::arch_type;
   static constexpr auto padding = get_padding<FLT, 2 * ns>();
-  static constexpr auto alignment = simd_type::arch_type::alignment();
+  static constexpr auto alignment = arch_t::alignment();
   static constexpr auto simd_size = simd_type::size;
   static constexpr auto ns2 = ns * FLT(0.5); // half spread width, used as stencil shift
 
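Note on the simd_type change above: switching from xsimd::batch<FLT> (the widest native batch) to PaddedSIMD<FLT, 2 * ns> ties the batch width to the interleaved kernel span 2*ns. The vectorized loops further down rely on the padded span filling whole batches. A minimal compile-time sketch of that assumption, using the file's own get_padding/PaddedSIMD helpers; the assert itself is illustrative only and assumes get_padding rounds 2*ns up to a multiple of the chosen batch size:

// Illustrative only: the invariant the vector loops assume, namely that
// 2*ns plus its padding is a whole number of SIMD batches.
template<uint8_t ns> constexpr bool padded_span_fills_batches() {
  using simd_type          = PaddedSIMD<FLT, 2 * ns>;     // file's own alias
  constexpr auto padding   = get_padding<FLT, 2 * ns>();  // file's own helper
  constexpr auto simd_size = simd_type::size;
  return (2 * ns + padding) % simd_size == 0;
}
static_assert(padded_span_fills_batches<7>(), "padded kernel span fills whole batches");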
@@ -521,15 +529,16 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
     if (!(opts.flags & TF_OMIT_SPREADING)) {
       switch (ndims) {
       case 1:
-        ker_eval<ns, kerevalmeth, FLT>(kernel_values.data(), opts, x1);
-        interp_line<ns>(target, data_uniform, ker1, i1, N1);
+        ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1);
+        interp_line<ns, simd_type>(target, data_uniform, ker1, i1, N1);
         break;
       case 2:
-        ker_eval<ns, kerevalmeth, FLT>(kernel_values.data(), opts, x1, x2);
+        ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2);
         interp_square<ns>(target, data_uniform, ker1, ker2, i1, i2, N1, N2);
         break;
       case 3:
-        ker_eval<ns, kerevalmeth, FLT>(kernel_values.data(), opts, x1, x2, x3);
+        ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2,
+                                                  x3);
         interp_cube<ns>(target, data_uniform, ker1, ker2, ker3, i1, i2, i3, N1, N2,
                         N3);
         break;
@@ -788,9 +797,9 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
   }
 }
 
-template<uint8_t ns>
-void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker, BIGINT i1,
-                 const BIGINT N1)
+template<uint8_t ns, class simd_type>
+void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
+                 const BIGINT i1, const BIGINT N1)
 /* 1D interpolate complex values from size-ns block of the du (uniform grid
    data) array to a single complex output value "target", using as weights the
    1d kernel evaluation list ker1.
@@ -808,38 +817,60 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker, BI
    Barnett 6/16/17.
 */
 {
+
   FLT out[] = {0.0, 0.0};
   BIGINT j = i1;
-  if (i1 < 0) { // wraps at left
+  if (FINUFFT_UNLIKELY(i1 < 0)) { // wraps at left
     j += N1;
-    for (int dx = 0; dx < -i1; ++dx) {
+    for (UBIGINT dx = 0; dx < -i1; ++dx) {
       out[0] += du[2 * j] * ker[dx];
       out[1] += du[2 * j + 1] * ker[dx];
       ++j;
     }
     j -= N1;
-    for (int dx = -i1; dx < ns; ++dx) {
+    for (UBIGINT dx = -i1; dx < ns; ++dx) {
       out[0] += du[2 * j] * ker[dx];
       out[1] += du[2 * j + 1] * ker[dx];
       ++j;
     }
-  } else if (i1 + ns >= N1) { // wraps at right
+  } else if (FINUFFT_UNLIKELY(i1 + ns >= N1)) { // wraps at right
     for (int dx = 0; dx < N1 - i1; ++dx) {
       out[0] += du[2 * j] * ker[dx];
       out[1] += du[2 * j + 1] * ker[dx];
       ++j;
     }
     j -= N1;
-    for (int dx = N1 - i1; dx < ns; ++dx) {
+    for (UBIGINT dx = N1 - i1; dx < ns; ++dx) {
       out[0] += du[2 * j] * ker[dx];
       out[1] += du[2 * j + 1] * ker[dx];
       ++j;
     }
   } else { // doesn't wrap
-    for (int dx = 0; dx < ns; ++dx) {
-      out[0] += du[2 * j] * ker[dx];
-      out[1] += du[2 * j + 1] * ker[dx];
-      ++j;
+    using arch_t = typename simd_type::arch_type;
+    static constexpr auto padding = get_padding<FLT, 2 * ns>();
+    static constexpr auto alignment = arch_t::alignment();
+    static constexpr auto simd_size = simd_type::size;
+    static constexpr auto regular_part = (2 * ns + padding) & (-(2 * simd_size));
+    simd_type res{0};
+    for (uint8_t dx{0}; dx < regular_part; dx += 2 * simd_size) {
+      const auto ker_v = simd_type::load_aligned(ker + dx / 2);
+      const auto du_pt0 = simd_type::load_unaligned(du + 2 * j + dx); // block starts at du[2*i1]
+      const auto du_pt1 = simd_type::load_unaligned(du + 2 * j + dx + simd_size);
+      const auto ker0low = xsimd::swizzle(ker_v, zip_low_index<arch_t>);
+      const auto ker0hi = xsimd::swizzle(ker_v, zip_hi_index<arch_t>);
+      res = xsimd::fma(ker0low, du_pt0, xsimd::fma(ker0hi, du_pt1, res));
+    }
+    if constexpr (regular_part < 2 * ns) {
+      const auto ker0 = simd_type::load_unaligned(ker + (regular_part / 2));
+      const auto du_pt = simd_type::load_unaligned(du + 2 * j + regular_part);
+      const auto ker0low = xsimd::swizzle(ker0, zip_low_index<arch_t>);
+      res = xsimd::fma(ker0low, du_pt, res);
+    }
+    alignas(alignment) std::array<FLT, simd_size> res_array{};
+    res.store_aligned(res_array.data());
+    for (uint8_t i{0}; i < simd_size; i += 2) {
+      out[0] += res_array[i];
+      out[1] += res_array[i + 1];
     }
   }
   target[0] = out[0];
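The non-wrapping branch above works on interleaved (re, im) data with a purely real kernel: each batch of kernel weights is duplicated pairwise by the zip_low/zip_hi swizzles so that one weight multiplies both the real and the imaginary slot of its grid point, and the even/odd lanes of res are reduced separately at the end. A minimal scalar sketch of that data flow; it is illustrative only and assumes an 8-lane batch with zip_low mapping lane i to i/2 and zip_hi mapping lane i to size/2 + i/2, consistent with how the indices are used above:

#include <array>
#include <cstdio>

int main() {
  constexpr int simd_size = 8;                                // assumed lanes per batch
  std::array<double, simd_size> ker{1, 2, 3, 4, 5, 6, 7, 8};  // k0..k7 (real weights)
  std::array<double, 2 * simd_size> du{};                     // interleaved re,im pairs
  for (int i = 0; i < 2 * simd_size; ++i) du[i] = 0.5 * i;

  std::array<double, simd_size> res{};                        // interleaved partial sums
  for (int lane = 0; lane < simd_size; ++lane) {
    // zip_low: lane -> ker[lane/2]; multiplies du_pt0 (first simd_size entries)
    res[lane] += ker[lane / 2] * du[lane];
    // zip_hi: lane -> ker[simd_size/2 + lane/2]; multiplies du_pt1 (next batch)
    res[lane] += ker[simd_size / 2 + lane / 2] * du[simd_size + lane];
  }
  double re = 0, im = 0;                                      // horizontal reduction
  for (int lane = 0; lane < simd_size; lane += 2) {
    re += res[lane];
    im += res[lane + 1];
  }
  std::printf("re = %g, im = %g\n", re, im);                  // = sum_dx ker[dx]*du[2*dx(+1)]
}

The zero padding appended to ker (see get_padding) is what lets the last, partially filled batch in the tail step be processed with a full-width load: the extra kernel lanes are zero and contribute nothing.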
@@ -1929,6 +1960,13 @@ struct zip_hi {
   }
 };
 
+struct select_even {
+  static constexpr bool get(unsigned index, unsigned /*size*/) { return index % 2 == 0; }
+};
+struct select_odd {
+  static constexpr bool get(unsigned index, unsigned /*size*/) { return index % 2 == 1; }
+};
+
 void print_subgrid_info(int ndims, BIGINT offset1, BIGINT offset2, BIGINT offset3,
                         BIGINT padded_size1, BIGINT size1, BIGINT size2, BIGINT size3,
                         BIGINT M0) {
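These generators back the select_even_mask / select_odd_mask constants declared near the top of the file: handed to xsimd::make_batch_bool_constant they yield a compile-time alternating lane mask, i.e. the real and imaginary slots of interleaved complex data. A small self-contained sketch of the semantics; the xsimd::select call is an assumed usage pattern for such a mask, not something taken from the hunks of this commit:

#include <cstdio>
#include <xsimd/xsimd.hpp>

// Same generator shape as above: true on even lanes.
struct select_even {
  static constexpr bool get(unsigned index, unsigned /*size*/) { return index % 2 == 0; }
};

int main() {
  using batch  = xsimd::batch<float>; // default arch, e.g. 8 lanes with AVX
  using arch_t = batch::arch_type;
  constexpr auto even_mask = xsimd::make_batch_bool_constant<float, arch_t, select_even>();
  const batch a(1.0f), b(-1.0f);
  // keep a's even (real) lanes and b's odd (imag) lanes: 1, -1, 1, -1, ...
  const auto blended = xsimd::select(even_mask, a, b);
  float out[batch::size];
  blended.store_unaligned(out);
  for (std::size_t i = 0; i < batch::size; ++i) std::printf("%g ", out[i]);
  std::printf("\n");
}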