Skip to content

Commit f5e71c8

Browse files
committed
commented-spreader
1 parent a35199f commit f5e71c8

File tree

1 file changed

+100
-43
lines changed

1 file changed

+100
-43
lines changed

src/spreadinterp.cpp

Lines changed: 100 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ static FINUFFT_ALWAYS_INLINE auto ker_eval(FLT *FINUFFT_RESTRICT ker,
6767
const finufft_spread_opts &opts,
6868
const V... elems) noexcept;
6969
static FINUFFT_ALWAYS_INLINE FLT fold_rescale(FLT x, BIGINT N) noexcept;
70+
template<class simd_type>
71+
static simd_type fold_rescale(const simd_type &x, const BIGINT N) noexcept;
7072
static FINUFFT_ALWAYS_INLINE void set_kernel_args(
7173
FLT *args, FLT x, const finufft_spread_opts &opts) noexcept;
7274
static FINUFFT_ALWAYS_INLINE void evaluate_kernel_vector(
@@ -389,7 +391,7 @@ int spreadSorted(const BIGINT *sort_indices, BIGINT N1, BIGINT N2, BIGINT N3,
389391
#pragma omp parallel num_threads(nthr)
390392
{
391393
// local copies of NU pts and data for each subproblem
392-
std::vector<FLT> kx0{0}, ky0{0}, kz0{0}, dd0{0}, du0{0};
394+
std::vector<FLT> kx0{}, ky0{}, kz0{}, dd0{}, du0{};
393395
#pragma omp for schedule(dynamic, 1) // each is big
394396
for (int isub = 0; isub < nb; isub++) { // Main loop through the subproblems
395397
const BIGINT M0 = brk[isub + 1] - brk[isub]; // # NU pts in this subproblem
@@ -482,11 +484,10 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
482484
timer.start();
483485
#pragma omp parallel num_threads(nthr)
484486
{
485-
static constexpr auto CHUNKSIZE = 16; // Chunks of Type 2 targets (Ludvig found by
486-
// expt)
487-
BIGINT jlist[CHUNKSIZE];
488-
FLT xjlist[CHUNKSIZE], yjlist[CHUNKSIZE], zjlist[CHUNKSIZE];
489-
FLT outbuf[2 * CHUNKSIZE];
487+
static constexpr auto CHUNKSIZE = simd_size; // number of targets per chunk
488+
alignas(alignment) UBIGINT jlist[CHUNKSIZE];
489+
alignas(alignment) FLT xjlist[CHUNKSIZE], yjlist[CHUNKSIZE], zjlist[CHUNKSIZE];
490+
alignas(alignment) FLT outbuf[2 * CHUNKSIZE];
490491
// Kernels: static alloc is faster, so we do it for up to 3D...
491492
alignas(alignment) std::array<FLT, 3 * MAX_NSPREAD> kernel_values{0};
492493
auto *FINUFFT_RESTRICT ker1 = kernel_values.data();
@@ -497,7 +498,6 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
497498
#pragma omp for schedule(dynamic, 1000) // assign threads to NU targ pts:
498499
for (BIGINT i = 0; i < M; i += CHUNKSIZE) // main loop over NU trgs, interp each from
499500
// U
500-
501501
{
502502
// Setup buffers for this chunk
503503
const int bufsize = (i + CHUNKSIZE > M) ? M - i : CHUNKSIZE;
@@ -518,14 +518,16 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
518518
auto *FINUFFT_RESTRICT target = outbuf + 2 * ibuf;
519519

520520
// coords (x,y,z), spread block corner index (i1,i2,i3) of current NU targ
521-
const auto i1 = (BIGINT)std::ceil(xj - ns2); // leftmost grid index
522-
const auto i2 = (ndims > 1) ? (BIGINT)std::ceil(yj - ns2) : 0; // min y grid index
523-
const auto i3 = (ndims > 2) ? (BIGINT)std::ceil(zj - ns2) : 0; // min z grid index
521+
const auto i1 = BIGINT(std::ceil(xj - ns2)); // leftmost grid index
522+
const auto i2 = (ndims > 1) ? BIGINT(std::ceil(yj - ns2)) : 0; // min y grid index
523+
const auto i3 = (ndims > 2) ? BIGINT(std::ceil(zj - ns2)) : 0; // min z grid index
524524

525-
const auto x1 = (FLT)i1 - xj; // shift of ker center, in [-w/2,-w/2+1]
526-
const auto x2 = (ndims > 1) ? (FLT)i2 - yj : 0;
527-
const auto x3 = (ndims > 2) ? (FLT)i3 - zj : 0;
525+
const auto x1 = std::ceil(xj - ns2) - xj; // shift of ker center, in [-w/2,-w/2+1]
526+
const auto x2 = (ndims > 1) ? std::ceil(yj - ns2) - yj : 0;
527+
const auto x3 = (ndims > 2) ? std::ceil(zj - ns2) - zj : 0;
528528

529+
ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1);
530+
interp_line<ns, simd_type>(target, data_uniform, ker1, i1, N1);
529531
// eval kernel values patch and use to interpolate from uniform data...
530532
if (!(opts.flags & TF_OMIT_SPREADING)) {
531533
switch (ndims) {
@@ -800,37 +802,42 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
800802
}
801803

802804
template<uint8_t ns>
803-
void interp_line_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
804-
const BIGINT i1, const BIGINT N1) {
805+
static void interp_line_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
806+
const BIGINT i1, const BIGINT N1) {
807+
/* This function is called when the kernel wraps around the grid. It is
808+
slower than interp_line.
809+
M. Barbone July 2024: - moved the logic to a separate function
810+
- using fused multiply-add (fma) for better performance
811+
*/
805812
std::array<FLT, 2> out{0};
806813
BIGINT j = i1;
807814
if (i1 < 0) { // wraps at left
808815
j += N1;
809816
for (uint8_t dx = 0; dx < -i1; ++dx, ++j) {
810-
out[0] = xsimd::fma(du[2 * j], ker[dx], out[0]);
811-
out[1] = xsimd::fma(du[2 * j + 1], ker[dx], out[1]);
817+
out[0] = std::fma(du[2 * j], ker[dx], out[0]);
818+
out[1] = std::fma(du[2 * j + 1], ker[dx], out[1]);
812819
}
813820
j -= N1;
814821
for (uint8_t dx = -i1; dx < ns; ++dx, ++j) {
815-
out[0] = xsimd::fma(du[2 * j], ker[dx], out[0]);
816-
out[1] = xsimd::fma(du[2 * j + 1], ker[dx], out[1]);
822+
out[0] = std::fma(du[2 * j], ker[dx], out[0]);
823+
out[1] = std::fma(du[2 * j + 1], ker[dx], out[1]);
817824
}
818825
} else if (i1 + ns >= N1) { // wraps at right
819826
for (uint8_t dx = 0; dx < N1 - i1; ++dx, ++j) {
820-
out[0] = xsimd::fma(du[2 * j], ker[dx], out[0]);
821-
out[1] = xsimd::fma(du[2 * j + 1], ker[dx], out[1]);
827+
out[0] = std::fma(du[2 * j], ker[dx], out[0]);
828+
out[1] = std::fma(du[2 * j + 1], ker[dx], out[1]);
822829
}
823830
j -= N1;
824831
for (uint8_t dx = N1 - i1; dx < ns; ++dx, ++j) {
825-
out[0] = xsimd::fma(du[2 * j], ker[dx], out[0]);
826-
out[1] = xsimd::fma(du[2 * j + 1], ker[dx], out[1]);
832+
out[0] = std::fma(du[2 * j], ker[dx], out[0]);
833+
out[1] = std::fma(du[2 * j + 1], ker[dx], out[1]);
827834
}
828835
} else {
829836
// padding is okay for ker, but it might spill over du array
830837
// so this checks for that case and does not explicitly vectorize
831838
for (uint8_t dx = 0; dx < ns; ++dx, ++j) {
832-
out[0] = xsimd::fma(du[2 * j], ker[dx], out[0]);
833-
out[1] = xsimd::fma(du[2 * j + 1], ker[dx], out[1]);
839+
out[0] = std::fma(du[2 * j], ker[dx], out[0]);
840+
out[1] = std::fma(du[2 * j + 1], ker[dx], out[1]);
834841
}
835842
}
836843
target[0] = out[0];
@@ -839,8 +846,8 @@ void interp_line_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ke
839846

840847
template<uint8_t ns, class simd_type>
841848
void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
842-
const BIGINT i1, const BIGINT N1)
843-
/* 1D interpolate complex values from size-ns block of the du (uniform grid
849+
const BIGINT i1, const BIGINT N1) {
850+
/* 1D interpolate complex values from size-ns block of the du (uniform grid
844851
data) array to a single complex output value "target", using as weights the
845852
1d kernel evaluation list ker1.
846853
Inputs:
@@ -855,8 +862,10 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
855862
Periodic wrapping in the du array is applied, assuming N1>=ns.
856863
Internally, dx indices into ker array j is index in complex du array.
857864
Barnett 6/16/17.
865+
M. Barbone July 2024: - moved wrapping logic to interp_line_wrap
866+
- using explicit SIMD vectorization to overcome the out[2] array
867+
limitation
858868
*/
859-
{
860869
using arch_t = typename simd_type::arch_type;
861870
static constexpr auto padding = get_padding<FLT, 2 * ns>();
862871
static constexpr auto alignment = arch_t::alignment();
@@ -865,9 +874,12 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
865874
std::array<FLT, 2> out{0};
866875
const auto j = i1;
867876
// removing the wrapping leads up to 10% speedup in certain cases
877+
// moved the wrapping to another function to reduce instruction cache pressure
868878
if (i1 < 0 || i1 + ns >= N1 || i1 + ns + (padding + 1) / 2 >= N1) {
869879
return interp_line_wrap<ns>(target, du, ker, i1, N1);
870880
} else { // doesn't wrap
881+
// logic largely similar to spread 1D kernel, please see the explanation there
882+
// for the first part of this code
871883
const auto du_ptr = du + 2 * j;
872884
simd_type res_low{0}, res_hi{0};
873885
for (uint8_t dx{0}; dx < regular_part; dx += 2 * simd_size) {
@@ -885,13 +897,24 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
885897
const auto ker0low = xsimd::swizzle(ker0, zip_low_index<arch_t>);
886898
res_low = xsimd::fma(ker0low, du_pt, res_low);
887899
}
888-
// This is slower than summing and looping
900+
901+
// this is where the code differs from spread_kernel, the interpolator does an extra
902+
// reduction step to SIMD elements down to 2 elements
903+
// This is known as horizontal sum in SIMD terminology
904+
905+
// This does a horizontal sum using vector instruction,
906+
// is slower than summing and looping
889907
// clang-format off
890908
// const auto res_real = xsimd::shuffle(res_low, res_hi, select_even_mask<arch_t>);
891909
// const auto res_imag = xsimd::shuffle(res_low, res_hi, select_odd_mask<arch_t>);
892910
// out[0] = xsimd::reduce_add(res_real);
893911
// out[1] = xsimd::reduce_add(res_imag);
894912
// clang-format on
913+
914+
// This does a horizontal sum using a loop instead of relying on SIMD instructions
915+
// this is faster than the above code but less elegant.
916+
// lambdas here to limit the scope of temporary variables and have the compiler
917+
// optimize the code better
895918
alignas(alignment) const auto res_array = [](const auto &res_low,
896919
const auto &res_hi) constexpr noexcept {
897920
alignas(alignment) std::array<FLT, simd_size> res_array{};
@@ -909,9 +932,15 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
909932
}
910933

911934
template<uint8_t ns, class simd_type>
912-
void interp_square_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
913-
const FLT *ker2, const BIGINT i1, const BIGINT i2,
914-
const BIGINT N1, const BIGINT N2) {
935+
static void interp_square_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du,
936+
const FLT *ker1, const FLT *ker2, const BIGINT i1,
937+
const BIGINT i2, const BIGINT N1, const BIGINT N2) {
938+
/*
939+
* This function is called when the kernel wraps around the grid. It is slower than
940+
* the non wrapping version.
941+
* There is an extra case for when ker is padded and spills over the du array.
942+
* In this case uses the old non wrapping version.
943+
*/
915944
std::array<FLT, 2> out{0};
916945
using arch_t = typename simd_type::arch_type;
917946
static constexpr auto padding = get_padding<FLT, 2 * ns>();
@@ -930,13 +959,13 @@ void interp_square_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *
930959
for (uint8_t dy{1}; dy < ns; ++dy) {
931960
const auto *l_ptr = du + 2 * (N1 * (i2 + dy) + i1); // (see above)
932961
for (uint8_t l{0}; l < 2 * ns; ++l) {
933-
line[l] = xsimd::fma(ker2[dy], l_ptr[l], line[l]);
962+
line[l] = std::fma(ker2[dy], l_ptr[l], line[l]);
934963
}
935964
}
936965
// apply x kernel to the (interleaved) line and add together
937966
for (uint8_t dx{0}; dx < ns; dx++) {
938-
out[0] = xsimd::fma(line[2 * dx], ker1[dx], out[0]);
939-
out[1] = xsimd::fma(line[2 * dx + 1], ker1[dx], out[1]);
967+
out[0] = std::fma(line[2 * dx], ker1[dx], out[0]);
968+
out[1] = std::fma(line[2 * dx + 1], ker1[dx], out[1]);
940969
}
941970
} else {
942971
std::array<UBIGINT, ns> j1{}, j2{}; // 1d ptr lists
@@ -993,6 +1022,10 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
9931022
pretty badly. I think this is now much more analogous to the way the spread
9941023
operation is implemented, which has always been much faster when I tested
9951024
it."
1025+
M. Barbone July 2024: - moved the wrapping logic to interp_square_wrap
1026+
- using explicit SIMD vectorization to overcome the out[2] array
1027+
limitation
1028+
The code is largely similar to 1D interpolation, please see the explanation there
9961029
*/
9971030
{
9981031
std::array<FLT, 2> out{0};
@@ -1006,20 +1039,24 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
10061039
if (i1 >= 0 && i1 + ns <= N1 && i2 >= 0 && i2 + ns <= N2 &&
10071040
(i1 + ns + (padding + 1) / 2 < N1)) {
10081041
const auto line = [du, N1, i2, i1, ker2]() constexpr noexcept {
1042+
// new array du_pts to store the du values for the current y line
10091043
std::array<simd_type, line_vectors> line{}, du_pts{};
10101044
// block for first y line, to avoid explicitly initializing line with zeros
10111045
{
10121046
const auto l_ptr = du + 2 * (N1 * i2 + i1); // ptr to horiz line start in du
10131047
const simd_type ker2_v{ker2[0]};
10141048
for (uint8_t l{0}; l < line_vectors; ++l) {
10151049
// l is like dx but for ns interleaved
1050+
// no fancy trick needed to multiply real,imag by ker2
10161051
line[l] = ker2_v * simd_type::load_unaligned(l * simd_size + l_ptr);
10171052
}
10181053
}
10191054
// add remaining const-y lines to the line (expensive inner loop)
10201055
for (uint8_t dy{1}; dy < ns; dy++) {
10211056
const auto l_ptr = du + 2 * (N1 * (i2 + dy) + i1); // (see above)
1057+
// vectorize over the fast axis of the du array
10221058
const simd_type ker2_v{ker2[dy]};
1059+
// First loop: Load all du_pt into the du_pts array
10231060
for (uint8_t l{0}; l < line_vectors; ++l) {
10241061
du_pts[l] = simd_type::load_unaligned(l * simd_size + l_ptr);
10251062
}
@@ -1030,7 +1067,9 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
10301067
}
10311068
return line;
10321069
}();
1033-
const auto res = [ker1](const auto &line) {
1070+
// This is the same as 1D interpolation
1071+
// using lambda to limit the scope of the temporary variables
1072+
const auto res = [ker1](const auto &line) constexpr noexcept {
10341073
// apply x kernel to the (interleaved) line and add together
10351074
simd_type res_low{0}, res_hi{0};
10361075
for (uint8_t i = 0; i < (line_vectors & ~1); // NOLINT(*-too-small-loop-variable)
@@ -1065,10 +1104,14 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
10651104
}
10661105

10671106
template<uint8_t ns, class simd_type>
1068-
void interp_cube_wrapped(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
1069-
const FLT *ker2, const FLT *ker3, const BIGINT i1,
1070-
const BIGINT i2, const BIGINT i3, const BIGINT N1,
1071-
const BIGINT N2, const BIGINT N3) {
1107+
static void interp_cube_wrapped(FLT *FINUFFT_RESTRICT target, const FLT *du,
1108+
const FLT *ker1, const FLT *ker2, const FLT *ker3,
1109+
const BIGINT i1, const BIGINT i2, const BIGINT i3,
1110+
const BIGINT N1, const BIGINT N2, const BIGINT N3) {
1111+
/*
1112+
* This function is called when the kernel wraps around the cube.
1113+
* Similarly to 2D and 1D wrapping, this is slower than the non wrapping version.
1114+
*/
10721115
using arch_t = typename simd_type::arch_type;
10731116
static constexpr auto padding = get_padding<FLT, 2 * ns>();
10741117
static constexpr auto alignment = arch_t::alignment();
@@ -1078,6 +1121,8 @@ void interp_cube_wrapped(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT
10781121
const auto in_bounds_2 = (i2 >= 0) & (i2 + ns <= N2);
10791122
const auto in_bounds_3 = (i3 >= 0) & (i3 + ns <= N3);
10801123
std::array<FLT, 2> out{0};
1124+
// case no wrapping needed but padding spills over du array.
1125+
// Hence, no explicit vectorization but the code is still faster
10811126
if (FINUFFT_LIKELY(in_bounds_1 && in_bounds_2 && in_bounds_3)) {
10821127
// no wrapping: avoid ptrs (by far the most common case)
10831128
// store a horiz line (interleaved real,imag)
@@ -1091,14 +1136,14 @@ void interp_cube_wrapped(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT
10911136
const auto l_ptr = du + 2 * (oz + N1 * (i2 + dy) + i1); // ptr start of line
10921137
const auto ker23 = ker2[dy] * ker3[dz];
10931138
for (uint8_t l{0}; l < 2 * ns; ++l) { // loop over ns interleaved (R,I) pairs
1094-
line[l] = xsimd::fma(l_ptr[l], ker23, line[l]);
1139+
line[l] = std::fma(l_ptr[l], ker23, line[l]);
10951140
}
10961141
}
10971142
}
10981143
// apply x kernel to the (interleaved) line and add together (cheap)
10991144
for (uint8_t dx{0}; dx < ns; ++dx) {
1100-
out[0] += line[2 * dx] * ker1[dx];
1101-
out[1] += line[2 * dx + 1] * ker1[dx];
1145+
out[0] = std::fma(line[2 * dx], ker1[dx], out[0]);
1146+
out[1] = std::fma(line[2 * dx + 1], ker1[dx], out[1]);
11021147
}
11031148
} else {
11041149
// ...can be slower since this case only happens with probability
@@ -1160,6 +1205,11 @@ void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
11601205
Barnett 6/16/17.
11611206
No-wrap case sped up for FMA/SIMD by Reinecke 6/19/23
11621207
(see above note in interp_square)
1208+
Barbone July 2024: - moved wrapping logic to interp_cube_wrapped
1209+
- using explicit SIMD vectorization to overcome the out[2] array
1210+
limitation
1211+
The code is largely similar to 2D and 1D interpolation, please see the explanation
1212+
there
11631213
*/
11641214
{
11651215
using arch_t = typename simd_type::arch_type;
@@ -1985,6 +2035,13 @@ FLT fold_rescale(const FLT x, const BIGINT N) noexcept {
19852035
return (result - floor(result)) * FLT(N);
19862036
}
19872037

2038+
template<class simd_type>
2039+
simd_type fold_rescale(const simd_type &x, const BIGINT N) noexcept {
2040+
const simd_type x2pi = FLT(M_1_2PI);
2041+
const simd_type result = xsimd::fma(x, x2pi, simd_type(0.5));
2042+
return (result - xsimd::floor(result)) * simd_type(FLT(N));
2043+
}
2044+
19882045
template<uint8_t ns, uint8_t kerevalmeth, class T, class simd_type, typename... V>
19892046
auto ker_eval(FLT *FINUFFT_RESTRICT ker, const finufft_spread_opts &opts,
19902047
const V... elems) noexcept {

0 commit comments

Comments
 (0)