Skip to content

Commit 5b9731f

Browse files
committed
Vectorized 1D and 2D
1 parent 88129a4 commit 5b9731f

File tree

1 file changed

+66
-31
lines changed

1 file changed

+66
-31
lines changed

src/spreadinterp.cpp

Lines changed: 66 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ static FINUFFT_ALWAYS_INLINE void eval_kernel_vec_Horner(
7878
template<uint8_t ns, class simd_type = PaddedSIMD<FLT, 2 * ns>>
7979
static void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
8080
BIGINT i1, BIGINT N1);
81-
template<uint8_t ns>
81+
template<uint8_t ns, class simd_type = PaddedSIMD<FLT, 2 * ns>>
8282
static void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
8383
const FLT *ker2, BIGINT i1, BIGINT i2, BIGINT N1, BIGINT N2);
8484
template<uint8_t ns>
@@ -534,7 +534,8 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
534534
break;
535535
case 2:
536536
ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2);
537-
interp_square<ns>(target, data_uniform, ker1, ker2, i1, i2, N1, N2);
537+
interp_square<ns, simd_type>(target, data_uniform, ker1, ker2, i1, i2, N1,
538+
N2);
538539
break;
539540
case 3:
540541
ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2,
@@ -819,6 +820,7 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
819820
{
820821
std::array<FLT, 2> out{0};
821822
BIGINT j = i1;
823+
// removing the wrapping leads up to 10% speedup in certain cases
822824
if (FINUFFT_UNLIKELY(i1 < 0)) { // wraps at left
823825
j += N1;
824826
for (UBIGINT dx = 0; dx < -i1; ++dx, ++j) {
@@ -846,28 +848,30 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
846848
static constexpr auto alignment = arch_t::alignment();
847849
static constexpr auto simd_size = simd_type::size;
848850
static constexpr auto regular_part = (2 * ns + padding) & (-(2 * simd_size));
851+
const auto du_ptr = du + 2 * j;
849852
simd_type res_low{0}, res_hi{0};
850853
for (uint8_t dx{0}; dx < regular_part; dx += 2 * simd_size) {
851854
const auto ker_v = simd_type::load_aligned(ker + dx / 2);
852-
const auto du_pt0 = simd_type::load_unaligned(du + dx);
853-
const auto du_pt1 = simd_type::load_unaligned(du + dx + simd_size);
855+
const auto du_pt0 = simd_type::load_unaligned(du_ptr + dx);
856+
const auto du_pt1 = simd_type::load_unaligned(du_ptr + dx + simd_size);
854857
const auto ker0low = xsimd::swizzle(ker_v, zip_low_index<arch_t>);
855858
const auto ker0hi = xsimd::swizzle(ker_v, zip_hi_index<arch_t>);
856859
res_low = xsimd::fma(ker0low, du_pt0, res_low);
857860
res_hi = xsimd::fma(ker0hi, du_pt1, res_hi);
858861
}
859862
if constexpr (regular_part < 2 * ns) {
860863
const auto ker0 = simd_type::load_unaligned(ker + (regular_part / 2));
861-
const auto du_pt = simd_type::load_unaligned(du + regular_part);
864+
const auto du_pt = simd_type::load_unaligned(du_ptr + regular_part);
862865
const auto ker0low = xsimd::swizzle(ker0, zip_low_index<arch_t>);
863866
res_low = xsimd::fma(ker0low, du_pt, res_low);
864867
}
865868
// This is slower than summing and looping
866-
// const auto res_real = xsimd::shuffle(res_low, res_hi,
867-
// select_even_mask<arch_t>); const auto res_imag = xsimd::shuffle(res_low,
868-
// res_hi, select_odd_mask<arch_t>); out[0] = xsimd::reduce_add(res_real); out[1]
869-
// = xsimd::reduce_add(res_imag);
870-
869+
// clang-format off
870+
// const auto res_real = xsimd::shuffle(res_low, res_hi, select_even_mask<arch_t>);
871+
// const auto res_imag = xsimd::shuffle(res_low, res_hi, select_odd_mask<arch_t>);
872+
// out[0] = xsimd::reduce_add(res_real);
873+
// out[1] = xsimd::reduce_add(res_imag);
874+
// clang-format on
871875
const auto res = res_low + res_hi;
872876
alignas(alignment) std::array<FLT, simd_size> res_array{};
873877
res.store_aligned(res_array.data());
@@ -880,7 +884,7 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
880884
target[1] = out[1];
881885
}
882886

883-
template<uint8_t ns>
887+
template<uint8_t ns, class simd_type>
884888
void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
885889
const FLT *ker2, const BIGINT i1, const BIGINT i2, const BIGINT N1,
886890
const BIGINT N2)
@@ -914,42 +918,73 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
914918
{
915919
FLT out[] = {0.0, 0.0};
916920
if (i1 >= 0 && i1 + ns <= N1 && i2 >= 0 && i2 + ns <= N2) { // no wrapping: avoid ptrs
917-
FLT line[2 * MAX_NSPREAD]; // store a horiz line (interleaved real,imag)
918-
// block for first y line, to avoid explicitly initializing line with zeros
919-
{
920-
const FLT *lptr = du + 2 * (N1 * i2 + i1); // ptr to horiz line start in du
921-
for (int l = 0; l < 2 * ns; l++) { // l is like dx but for ns interleaved
922-
line[l] = ker2[0] * lptr[l];
921+
using arch_t = typename simd_type::arch_type;
922+
static constexpr auto padding = get_padding<FLT, 2 * ns>();
923+
static constexpr auto alignment = arch_t::alignment();
924+
static constexpr auto simd_size = simd_type::size;
925+
static constexpr uint8_t regular_part = (2 * ns + padding) & (-(2 * simd_size));
926+
static constexpr uint8_t line_vectors = (2 * ns + padding) / simd_size;
927+
const auto line = [du, N1, i2, i1, ker2]() {
928+
std::array<simd_type, line_vectors> line{};
929+
// block for first y line, to avoid explicitly initializing line with zeros
930+
{
931+
const auto l_ptr = du + 2 * (N1 * i2 + i1); // ptr to horiz line start in du
932+
const simd_type ker2_v{ker2[0]};
933+
for (uint8_t l{0}; l < line_vectors; ++l) {
934+
// l is like dx but for ns interleaved
935+
line[l] = ker2_v * simd_type::load_unaligned(l * simd_size + l_ptr);
936+
}
923937
}
924-
}
925-
// add remaining const-y lines to the line (expensive inner loop)
926-
for (int dy = 1; dy < ns; dy++) {
927-
const FLT *lptr = du + 2 * (N1 * (i2 + dy) + i1); // (see above)
928-
for (int l = 0; l < 2 * ns; ++l) {
929-
line[l] += ker2[dy] * lptr[l];
938+
// add remaining const-y lines to the line (expensive inner loop)
939+
for (uint8_t dy{1}; dy < ns; dy++) {
940+
const auto l_ptr = du + 2 * (N1 * (i2 + dy) + i1); // (see above)
941+
const simd_type ker2_v{ker2[dy]};
942+
for (uint8_t l{0}; l < line_vectors; ++l) {
943+
line[l] = xsimd::fma(ker2_v, simd_type::load_unaligned(l * simd_size + l_ptr),
944+
line[l]);
945+
}
930946
}
931-
}
947+
return line;
948+
}();
932949
// apply x kernel to the (interleaved) line and add together
933-
for (int dx = 0; dx < ns; dx++) {
934-
out[0] += line[2 * dx] * ker1[dx];
935-
out[1] += line[2 * dx + 1] * ker1[dx];
950+
simd_type res_low{0}, res_hi{0};
951+
for (uint8_t i = 0; i < (line_vectors & ~1); // NOLINT(*-too-small-loop-variable)
952+
i += 2) {
953+
const auto ker1_v = simd_type::load_aligned(ker1 + i * simd_size / 2);
954+
const auto ker1low = xsimd::swizzle(ker1_v, zip_low_index<arch_t>);
955+
const auto ker1hi = xsimd::swizzle(ker1_v, zip_hi_index<arch_t>);
956+
res_low = xsimd::fma(ker1low, line[i], res_low);
957+
res_hi = xsimd::fma(ker1hi, line[i + 1], res_hi);
958+
}
959+
if constexpr (line_vectors % 2) {
960+
const auto ker1_v =
961+
simd_type::load_aligned(ker1 + (line_vectors - 1) * simd_size / 2);
962+
const auto ker1low = xsimd::swizzle(ker1_v, zip_low_index<arch_t>);
963+
res_low = xsimd::fma(ker1low, line.back(), res_low);
964+
}
965+
const auto res = res_low + res_hi;
966+
alignas(alignment) std::array<FLT, simd_size> res_array{};
967+
res.store_aligned(res_array.data());
968+
for (uint8_t i{0}; i < simd_size; i += 2) {
969+
out[0] += res_array[i];
970+
out[1] += res_array[i + 1];
936971
}
937972
} else { // wraps somewhere: use ptr list
938973
// this is slower than above, but occurs much less often, with fractional
939974
// rate O(ns/min(N1,N2)). Thus this code doesn't need to be so optimized.
940975
BIGINT j1[MAX_NSPREAD], j2[MAX_NSPREAD]; // 1d ptr lists
941976
BIGINT x = i1, y = i2; // initialize coords
942-
for (int d = 0; d < ns; d++) { // set up ptr lists
977+
for (uint8_t d{0}; d < ns; d++) { // set up ptr lists
943978
if (x < 0) x += N1;
944979
if (x >= N1) x -= N1;
945980
j1[d] = x++;
946981
if (y < 0) y += N2;
947982
if (y >= N2) y -= N2;
948983
j2[d] = y++;
949984
}
950-
for (int dy = 0; dy < ns; dy++) { // use the pts lists
951-
BIGINT oy = N1 * j2[dy]; // offset due to y
952-
for (int dx = 0; dx < ns; dx++) {
985+
for (uint8_t dy{0}; dy < ns; dy++) { // use the pts lists
986+
BIGINT oy = N1 * j2[dy]; // offset due to y
987+
for (uint8_t dx{0}; dx < ns; dx++) {
953988
FLT k = ker1[dx] * ker2[dy];
954989
BIGINT j = oy + j1[dx];
955990
out[0] += du[2 * j] * k;

0 commit comments

Comments
 (0)