@@ -78,7 +78,7 @@ static FINUFFT_ALWAYS_INLINE void eval_kernel_vec_Horner(
 template <uint8_t ns, class simd_type = PaddedSIMD<FLT, 2 * ns>>
 static void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
                         BIGINT i1, BIGINT N1);
-template <uint8_t ns>
+template <uint8_t ns, class simd_type = PaddedSIMD<FLT, 2 * ns>>
 static void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
                           const FLT *ker2, BIGINT i1, BIGINT i2, BIGINT N1, BIGINT N2);
 template <uint8_t ns>
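This hunk gives `interp_square` the same defaulted `simd_type` parameter that `interp_line` already carries, so callers that only name `ns` keep compiling while `interpSorted_kernel` (next hunk) forwards its own `simd_type`. A minimal, generic sketch of that pattern (`PaddedVec` and `do_work` are illustrative names, not FINUFFT's):

```cpp
#include <cstddef>
#include <cstdint>

template <typename T, std::size_t N> struct PaddedVec {}; // stand-in for PaddedSIMD

// Defaulted template parameter: existing call sites that only name ns keep
// compiling, while tuned call sites may pass an explicit vector type.
template <std::uint8_t ns, class simd_type = PaddedVec<double, 2 * ns>>
void do_work() {}

int main() {
  do_work<7>();                        // uses the default PaddedVec<double, 14>
  do_work<7, PaddedVec<double, 16>>(); // explicit override, as interpSorted_kernel does
}
```

Note that the definition further down (`template <uint8_t ns, class simd_type>`) deliberately omits the default, since C++ does not allow a default template argument to be repeated in a later declaration.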
@@ -534,7 +534,8 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
             break;
           case 2:
             ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2);
-            interp_square<ns>(target, data_uniform, ker1, ker2, i1, i2, N1, N2);
+            interp_square<ns, simd_type>(target, data_uniform, ker1, ker2, i1, i2, N1,
+                                         N2);
             break;
           case 3:
             ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2,
@@ -819,6 +820,7 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
 {
   std::array<FLT, 2> out{0};
   BIGINT j = i1;
+  // removing the wrapping gives up to a 10% speedup in certain cases
   if (FINUFFT_UNLIKELY(i1 < 0)) { // wraps at left
     j += N1;
     for (UBIGINT dx = 0; dx < -i1; ++dx, ++j) {
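The new comment points at the cost of periodic wrapping: when the ns-wide window starting at `i1` lies fully inside the grid, `interp_line` can issue contiguous (and vectorizable) loads, which is where the quoted speedup comes from; only windows that hang off an edge take the wrapped branches. A scalar sketch of what those wrapped branches compute (`wrap_index` and `interp_line_wrapped_sketch` are hypothetical names, not the library's):

```cpp
#include <cstdint>

// Hypothetical helper: fold an index that falls at most one period outside
// [0, N1) back onto the grid, as the wrapped branches of interp_line do.
static std::int64_t wrap_index(std::int64_t j, std::int64_t N1) {
  if (j < 0) return j + N1;   // window hangs off the left edge
  if (j >= N1) return j - N1; // window hangs off the right edge
  return j;                   // common case: contiguous, no wrap needed
}

// Scalar sketch of a wrapped 1D interpolation: du holds interleaved
// (real, imag) values of a length-N1 grid, ker holds the ns kernel weights.
static void interp_line_wrapped_sketch(double *target, const double *du,
                                       const double *ker, std::int64_t i1,
                                       int ns, std::int64_t N1) {
  double re = 0, im = 0;
  for (int dx = 0; dx < ns; ++dx) {
    const std::int64_t j = wrap_index(i1 + dx, N1);
    re += ker[dx] * du[2 * j];
    im += ker[dx] * du[2 * j + 1];
  }
  target[0] = re;
  target[1] = im;
}
```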
@@ -846,28 +848,30 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
   static constexpr auto alignment = arch_t::alignment();
   static constexpr auto simd_size = simd_type::size;
   static constexpr auto regular_part = (2 * ns + padding) & (-(2 * simd_size));
+  const auto du_ptr = du + 2 * j;
   simd_type res_low{0}, res_hi{0};
   for (uint8_t dx{0}; dx < regular_part; dx += 2 * simd_size) {
     const auto ker_v = simd_type::load_aligned(ker + dx / 2);
-    const auto du_pt0 = simd_type::load_unaligned(du + dx);
-    const auto du_pt1 = simd_type::load_unaligned(du + dx + simd_size);
+    const auto du_pt0 = simd_type::load_unaligned(du_ptr + dx);
+    const auto du_pt1 = simd_type::load_unaligned(du_ptr + dx + simd_size);
     const auto ker0low = xsimd::swizzle(ker_v, zip_low_index<arch_t>);
     const auto ker0hi = xsimd::swizzle(ker_v, zip_hi_index<arch_t>);
     res_low = xsimd::fma(ker0low, du_pt0, res_low);
     res_hi = xsimd::fma(ker0hi, du_pt1, res_hi);
   }
   if constexpr (regular_part < 2 * ns) {
     const auto ker0 = simd_type::load_unaligned(ker + (regular_part / 2));
-    const auto du_pt = simd_type::load_unaligned(du + regular_part);
+    const auto du_pt = simd_type::load_unaligned(du_ptr + regular_part);
     const auto ker0low = xsimd::swizzle(ker0, zip_low_index<arch_t>);
     res_low = xsimd::fma(ker0low, du_pt, res_low);
   }
   // This is slower than summing and looping
-  // const auto res_real = xsimd::shuffle(res_low, res_hi,
-  // select_even_mask<arch_t>); const auto res_imag = xsimd::shuffle(res_low,
-  // res_hi, select_odd_mask<arch_t>); out[0] = xsimd::reduce_add(res_real); out[1]
-  // = xsimd::reduce_add(res_imag);
-
+  // clang-format off
+  // const auto res_real = xsimd::shuffle(res_low, res_hi, select_even_mask<arch_t>);
+  // const auto res_imag = xsimd::shuffle(res_low, res_hi, select_odd_mask<arch_t>);
+  // out[0] = xsimd::reduce_add(res_real);
+  // out[1] = xsimd::reduce_add(res_imag);
+  // clang-format on
   const auto res = res_low + res_hi;
   alignas(alignment) std::array<FLT, simd_size> res_array{};
   res.store_aligned(res_array.data());
@@ -880,7 +884,7 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
   target[1] = out[1];
 }

-template <uint8_t ns>
+template <uint8_t ns, class simd_type>
 void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
                    const FLT *ker2, const BIGINT i1, const BIGINT i2, const BIGINT N1,
                    const BIGINT N2)
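`interp_line` above applies a purely real kernel to interleaved (real, imag) grid data: the kernel register is swizzled with `zip_low_index` / `zip_hi_index` so each weight is duplicated into a pair of adjacent lanes, and a single `xsimd::fma` then updates the real and imaginary accumulators together. A scalar model of that lane layout (`zip_low_model` / `zip_hi_model` are illustrative stand-ins for the batch-constant indices defined elsewhere in this file):

```cpp
#include <array>
#include <cstddef>

// Scalar model of the zip_low / zip_hi swizzles on a register of N kernel
// weights [k0, k1, ..., k(N-1)]:
//   zip_low : [k0, k0, k1, k1, ...]            (duplicates the low  half)
//   zip_hi  : [k(N/2), k(N/2), k(N/2+1), ...]  (duplicates the high half)
// An element-wise FMA of the zipped weights with interleaved data
// [re0, im0, re1, im1, ...] accumulates both components of each point at
// once: even lanes gather real parts, odd lanes imaginary parts.
template <std::size_t N>
std::array<double, N> zip_low_model(const std::array<double, N> &ker) {
  std::array<double, N> out{};
  for (std::size_t lane = 0; lane < N; ++lane) out[lane] = ker[lane / 2];
  return out;
}

template <std::size_t N>
std::array<double, N> zip_hi_model(const std::array<double, N> &ker) {
  std::array<double, N> out{};
  for (std::size_t lane = 0; lane < N; ++lane) out[lane] = ker[N / 2 + lane / 2];
  return out;
}
```

The commented-out shuffle/reduce_add variant kept in the hunk above would split the accumulators into real and imaginary lanes before reducing; per its comment, that measured slower than simply adding `res_low` and `res_hi` and looping over the stored lanes.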
@@ -914,42 +918,73 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
 {
   FLT out[] = {0.0, 0.0};
   if (i1 >= 0 && i1 + ns <= N1 && i2 >= 0 && i2 + ns <= N2) { // no wrapping: avoid ptrs
-    FLT line[2 * MAX_NSPREAD]; // store a horiz line (interleaved real,imag)
-    // block for first y line, to avoid explicitly initializing line with zeros
-    {
-      const FLT *lptr = du + 2 * (N1 * i2 + i1); // ptr to horiz line start in du
-      for (int l = 0; l < 2 * ns; l++) { // l is like dx but for ns interleaved
-        line[l] = ker2[0] * lptr[l];
+    using arch_t = typename simd_type::arch_type;
+    static constexpr auto padding = get_padding<FLT, 2 * ns>();
+    static constexpr auto alignment = arch_t::alignment();
+    static constexpr auto simd_size = simd_type::size;
+    static constexpr uint8_t regular_part = (2 * ns + padding) & (-(2 * simd_size));
+    static constexpr uint8_t line_vectors = (2 * ns + padding) / simd_size;
+    const auto line = [du, N1, i2, i1, ker2]() {
+      std::array<simd_type, line_vectors> line{};
+      // block for first y line, to avoid explicitly initializing line with zeros
+      {
+        const auto l_ptr = du + 2 * (N1 * i2 + i1); // ptr to horiz line start in du
+        const simd_type ker2_v{ker2[0]};
+        for (uint8_t l{0}; l < line_vectors; ++l) {
+          // l is like dx but for ns interleaved
+          line[l] = ker2_v * simd_type::load_unaligned(l * simd_size + l_ptr);
+        }
       }
-    }
-    // add remaining const-y lines to the line (expensive inner loop)
-    for (int dy = 1; dy < ns; dy++) {
-      const FLT *lptr = du + 2 * (N1 * (i2 + dy) + i1); // (see above)
-      for (int l = 0; l < 2 * ns; ++l) {
-        line[l] += ker2[dy] * lptr[l];
+      // add remaining const-y lines to the line (expensive inner loop)
+      for (uint8_t dy{1}; dy < ns; dy++) {
+        const auto l_ptr = du + 2 * (N1 * (i2 + dy) + i1); // (see above)
+        const simd_type ker2_v{ker2[dy]};
+        for (uint8_t l{0}; l < line_vectors; ++l) {
+          line[l] = xsimd::fma(ker2_v, simd_type::load_unaligned(l * simd_size + l_ptr),
+                               line[l]);
+        }
       }
-    }
+      return line;
+    }();
     // apply x kernel to the (interleaved) line and add together
-    for (int dx = 0; dx < ns; dx++) {
-      out[0] += line[2 * dx] * ker1[dx];
-      out[1] += line[2 * dx + 1] * ker1[dx];
+    simd_type res_low{0}, res_hi{0};
+    for (uint8_t i = 0; i < (line_vectors & ~1); // NOLINT(*-too-small-loop-variable)
+         i += 2) {
+      const auto ker1_v = simd_type::load_aligned(ker1 + i * simd_size / 2);
+      const auto ker1low = xsimd::swizzle(ker1_v, zip_low_index<arch_t>);
+      const auto ker1hi = xsimd::swizzle(ker1_v, zip_hi_index<arch_t>);
+      res_low = xsimd::fma(ker1low, line[i], res_low);
+      res_hi = xsimd::fma(ker1hi, line[i + 1], res_hi);
+    }
+    if constexpr (line_vectors % 2) {
+      const auto ker1_v =
+          simd_type::load_aligned(ker1 + (line_vectors - 1) * simd_size / 2);
+      const auto ker1low = xsimd::swizzle(ker1_v, zip_low_index<arch_t>);
+      res_low = xsimd::fma(ker1low, line.back(), res_low);
+    }
+    const auto res = res_low + res_hi;
+    alignas(alignment) std::array<FLT, simd_size> res_array{};
+    res.store_aligned(res_array.data());
+    for (uint8_t i{0}; i < simd_size; i += 2) {
+      out[0] += res_array[i];
+      out[1] += res_array[i + 1];
     }
   } else { // wraps somewhere: use ptr list
     // this is slower than above, but occurs much less often, with fractional
     // rate O(ns/min(N1,N2)). Thus this code doesn't need to be so optimized.
     BIGINT j1[MAX_NSPREAD], j2[MAX_NSPREAD]; // 1d ptr lists
     BIGINT x = i1, y = i2; // initialize coords
-    for (int d = 0; d < ns; d++) { // set up ptr lists
+    for (uint8_t d{0}; d < ns; d++) { // set up ptr lists
       if (x < 0) x += N1;
       if (x >= N1) x -= N1;
       j1[d] = x++;
       if (y < 0) y += N2;
       if (y >= N2) y -= N2;
       j2[d] = y++;
     }
-    for (int dy = 0; dy < ns; dy++) { // use the pts lists
-      BIGINT oy = N1 * j2[dy]; // offset due to y
-      for (int dx = 0; dx < ns; dx++) {
+    for (uint8_t dy{0}; dy < ns; dy++) { // use the pts lists
+      BIGINT oy = N1 * j2[dy]; // offset due to y
+      for (uint8_t dx{0}; dx < ns; dx++) {
         FLT k = ker1[dx] * ker2[dy];
         BIGINT j = oy + j1[dx];
         out[0] += du[2 * j] * k;
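For orientation, the non-wrapped branch that this hunk vectorizes has a simple scalar equivalent (essentially the deleted code made self-contained): collapse the ns rows into one ker2-weighted interleaved line, then apply ker1 along that line. A sketch of that reference behaviour, assuming double precision (`interp_square_reference` is an illustrative name):

```cpp
#include <cstdint>
#include <vector>

// Reference (scalar) behaviour of the non-wrapped branch of interp_square:
// du is the interleaved (real, imag) uniform grid of size 2*N1*N2,
// ker1/ker2 hold the ns separable kernel weights in x and y,
// (i1, i2) is the lower-left corner of the ns-by-ns window (no wrapping).
static void interp_square_reference(double *target, const double *du,
                                    const double *ker1, const double *ker2,
                                    std::int64_t i1, std::int64_t i2, int ns,
                                    std::int64_t N1) {
  // collapse the ns rows into one ker2-weighted interleaved line
  std::vector<double> line(2 * ns, 0.0);
  for (int dy = 0; dy < ns; ++dy) {
    const double *row = du + 2 * (N1 * (i2 + dy) + i1);
    for (int l = 0; l < 2 * ns; ++l) line[l] += ker2[dy] * row[l];
  }
  // apply the x kernel along the interleaved line
  double re = 0, im = 0;
  for (int dx = 0; dx < ns; ++dx) {
    re += ker1[dx] * line[2 * dx];
    im += ker1[dx] * line[2 * dx + 1];
  }
  target[0] = re;
  target[1] = im;
}
```

The new code performs the same two passes, but keeps the line in `line_vectors` SIMD registers (the padded kernel width lets it load whole vectors with no scalar tail) and reuses the zip_low/zip_hi swizzle from `interp_line` to apply ker1 to the interleaved lanes, reducing at the end through an aligned `res_array` just as `interp_line` does.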