Skip to content

Commit 0ce3cb9

Browse files
committed
vectorized interp 1d
1 parent 684a447 commit 0ce3cb9

File tree

2 files changed

+63
-23
lines changed

2 files changed

+63
-23
lines changed

include/finufft/defs.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,19 @@
4444
#define FINUFFT_NEVER_INLINE __declspec(noinline)
4545
#define FINUFFT_RESTRICT __restrict
4646
#define FINUFFT_UNREACHABLE __assume(0)
47-
47+
#define FINUFFT_UNLIKELY(x) (x)
4848
#elif defined(__GNUC__) || defined(__clang__)
4949
#define FINUFFT_ALWAYS_INLINE __attribute__((always_inline)) inline
5050
#define FINUFFT_NEVER_INLINE __attribute__((noinline))
5151
#define FINUFFT_RESTRICT __restrict__
5252
#define FINUFFT_UNREACHABLE __builtin_unreachable()
53+
#define FINUFFT_UNLIKELY(x) __builtin_expect(!!(x), 0)
5354
#else
5455
#define FINUFFT_ALWAYS_INLINE inline
5556
#define FINUFFT_NEVER_INLINE
5657
#define FINUFFT_RESTRICT
5758
#define FINUFFT_UNREACHABLE
59+
#define FINUFFT_UNLIKELY(x) (x)
5860
#endif
5961

6062
// ------------- Library-wide algorithm parameter settings ----------------

src/spreadinterp.cpp

Lines changed: 60 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ namespace { // anonymous namespace for internal structs equivalent to declaring
2424
// static
2525
struct zip_low;
2626
struct zip_hi;
27+
struct select_even;
28+
struct select_odd;
2729
// forward declaration to clean up the code and be able to use this everywhere in the file
2830
template<class T, uint8_t N, uint8_t K = N> static constexpr auto BestSIMDHelper();
2931
template<class T, uint8_t N> constexpr auto GetPaddedSIMDWidth();
@@ -43,6 +45,12 @@ constexpr auto zip_low_index =
4345
template<class arch_t>
4446
constexpr auto zip_hi_index =
4547
xsimd::make_batch_constant<xsimd::as_unsigned_integer_t<FLT>, arch_t, zip_hi>();
48+
template<class arch_t>
49+
constexpr auto select_even_mask =
50+
xsimd::make_batch_bool_constant<FLT, arch_t, select_even>();
51+
template<class arch_t>
52+
constexpr auto select_odd_mask =
53+
xsimd::make_batch_bool_constant<FLT, arch_t, select_odd>();
4654
template<typename T, std::size_t N, std::size_t M, std::size_t PaddedM>
4755
constexpr std::array<std::array<T, PaddedM>, N> pad_2D_array_with_zeros(
4856
const std::array<std::array<T, M>, N> &input) noexcept;
@@ -67,14 +75,14 @@ template<uint8_t w, class simd_type = xsimd::make_sized_batch_t<
6775
FLT, find_optimal_simd_width<FLT, w>()>> // aka ns
6876
static FINUFFT_ALWAYS_INLINE void eval_kernel_vec_Horner(
6977
FLT *FINUFFT_RESTRICT ker, FLT x, const finufft_spread_opts &opts) noexcept;
70-
template<uint8_t ns>
71-
static void interp_line(FLT *FINUFFT_RESTRICT out, const FLT *du, const FLT *ker,
78+
template<uint8_t ns, class simd_type = PaddedSIMD<FLT, 2 * ns>>
79+
static void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
7280
BIGINT i1, BIGINT N1);
7381
template<uint8_t ns>
74-
static void interp_square(FLT *FINUFFT_RESTRICT out, const FLT *du, const FLT *ker1,
82+
static void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
7583
const FLT *ker2, BIGINT i1, BIGINT i2, BIGINT N1, BIGINT N2);
7684
template<uint8_t ns>
77-
static void interp_cube(FLT *FINUFFT_RESTRICT out, const FLT *du, const FLT *ker1,
85+
static void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
7886
const FLT *ker2, const FLT *ker3, BIGINT i1, BIGINT i2, BIGINT i3,
7987
BIGINT N1, BIGINT N2, BIGINT N3);
8088
static void spread_subproblem_1d(BIGINT off1, BIGINT size1, FLT *du0, BIGINT M0, FLT *kx0,
@@ -454,10 +462,10 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
454462
// Interpolate to NU pts in sorted order from a uniform grid.
455463
// See spreadinterp() for doc.
456464
{
457-
using simd_type = xsimd::batch<FLT>;
465+
using simd_type = PaddedSIMD<FLT, 2 * ns>;
458466
using arch_t = typename simd_type::arch_type;
459467
static constexpr auto padding = get_padding<FLT, 2 * ns>();
460-
static constexpr auto alignment = simd_type::arch_type::alignment();
468+
static constexpr auto alignment = arch_t::alignment();
461469
static constexpr auto simd_size = simd_type::size;
462470
static constexpr auto ns2 = ns * FLT(0.5); // half spread width, used as stencil shift
463471

@@ -521,15 +529,16 @@ FINUFFT_NEVER_INLINE static int interpSorted_kernel(
521529
if (!(opts.flags & TF_OMIT_SPREADING)) {
522530
switch (ndims) {
523531
case 1:
524-
ker_eval<ns, kerevalmeth, FLT>(kernel_values.data(), opts, x1);
525-
interp_line<ns>(target, data_uniform, ker1, i1, N1);
532+
ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1);
533+
interp_line<ns, simd_type>(target, data_uniform, ker1, i1, N1);
526534
break;
527535
case 2:
528-
ker_eval<ns, kerevalmeth, FLT>(kernel_values.data(), opts, x1, x2);
536+
ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2);
529537
interp_square<ns>(target, data_uniform, ker1, ker2, i1, i2, N1, N2);
530538
break;
531539
case 3:
532-
ker_eval<ns, kerevalmeth, FLT>(kernel_values.data(), opts, x1, x2, x3);
540+
ker_eval<ns, kerevalmeth, FLT, simd_type>(kernel_values.data(), opts, x1, x2,
541+
x3);
533542
interp_cube<ns>(target, data_uniform, ker1, ker2, ker3, i1, i2, i3, N1, N2,
534543
N3);
535544
break;
@@ -788,9 +797,9 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
788797
}
789798
}
790799

791-
template<uint8_t ns>
792-
void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker, BIGINT i1,
793-
const BIGINT N1)
800+
template<uint8_t ns, class simd_type>
801+
void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
802+
const BIGINT i1, const BIGINT N1)
794803
/* 1D interpolate complex values from size-ns block of the du (uniform grid
795804
data) array to a single complex output value "target", using as weights the
796805
1d kernel evaluation list ker1.
@@ -808,38 +817,60 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker, BI
808817
Barnett 6/16/17.
809818
*/
810819
{
820+
811821
FLT out[] = {0.0, 0.0};
812822
BIGINT j = i1;
813-
if (i1 < 0) { // wraps at left
823+
if (FINUFFT_UNLIKELY(i1 < 0)) { // wraps at left
814824
j += N1;
815-
for (int dx = 0; dx < -i1; ++dx) {
825+
for (UBIGINT dx = 0; dx < -i1; ++dx) {
816826
out[0] += du[2 * j] * ker[dx];
817827
out[1] += du[2 * j + 1] * ker[dx];
818828
++j;
819829
}
820830
j -= N1;
821-
for (int dx = -i1; dx < ns; ++dx) {
831+
for (UBIGINT dx = -i1; dx < ns; ++dx) {
822832
out[0] += du[2 * j] * ker[dx];
823833
out[1] += du[2 * j + 1] * ker[dx];
824834
++j;
825835
}
826-
} else if (i1 + ns >= N1) { // wraps at right
836+
} else if (FINUFFT_UNLIKELY(i1 + ns >= N1)) { // wraps at right
827837
for (int dx = 0; dx < N1 - i1; ++dx) {
828838
out[0] += du[2 * j] * ker[dx];
829839
out[1] += du[2 * j + 1] * ker[dx];
830840
++j;
831841
}
832842
j -= N1;
833-
for (int dx = N1 - i1; dx < ns; ++dx) {
843+
for (UBIGINT dx = N1 - i1; dx < ns; ++dx) {
834844
out[0] += du[2 * j] * ker[dx];
835845
out[1] += du[2 * j + 1] * ker[dx];
836846
++j;
837847
}
838848
} else { // doesn't wrap
839-
for (int dx = 0; dx < ns; ++dx) {
840-
out[0] += du[2 * j] * ker[dx];
841-
out[1] += du[2 * j + 1] * ker[dx];
842-
++j;
849+
using arch_t = typename simd_type::arch_type;
850+
static constexpr auto padding = get_padding<FLT, 2 * ns>();
851+
static constexpr auto alignment = arch_t::alignment();
852+
static constexpr auto simd_size = simd_type::size;
853+
static constexpr auto regular_part = (2 * ns + padding) & (-(2 * simd_size));
854+
simd_type res{0};
855+
for (uint8_t dx{0}; dx < regular_part; dx += 2 * simd_size) {
856+
const auto ker_v = simd_type::load_aligned(ker + dx / 2);
857+
const auto du_pt0 = simd_type::load_unaligned(du + dx);
858+
const auto du_pt1 = simd_type::load_unaligned(du + dx + simd_size);
859+
const auto ker0low = xsimd::swizzle(ker_v, zip_low_index<arch_t>);
860+
const auto ker0hi = xsimd::swizzle(ker_v, zip_hi_index<arch_t>);
861+
res = xsimd::fma(ker0low, du_pt0, xsimd::fma(ker0hi, du_pt1, res));
862+
}
863+
if constexpr (regular_part < 2 * ns) {
864+
const auto ker0 = simd_type::load_unaligned(ker + (regular_part / 2));
865+
const auto du_pt = simd_type::load_unaligned(du + regular_part);
866+
const auto ker0low = xsimd::swizzle(ker0, zip_low_index<arch_t>);
867+
res = xsimd::fma(ker0low, du_pt, res);
868+
}
869+
alignas(alignment) std::array<FLT, simd_size> res_array{};
870+
res.store_aligned(res_array.data());
871+
for (uint8_t i{0}; i < simd_size; i += 2) {
872+
out[0] += res_array[i];
873+
out[1] += res_array[i + 1];
843874
}
844875
}
845876
target[0] = out[0];
@@ -1929,6 +1960,13 @@ struct zip_hi {
19291960
}
19301961
};
19311962

1963+
struct select_even {
1964+
static constexpr bool get(unsigned index, unsigned /*size*/) { return index % 2 == 0; }
1965+
};
1966+
struct select_odd {
1967+
static constexpr bool get(unsigned index, unsigned /*size*/) { return index % 2 == 1; }
1968+
};
1969+
19321970
void print_subgrid_info(int ndims, BIGINT offset1, BIGINT offset2, BIGINT offset3,
19331971
BIGINT padded_size1, BIGINT size1, BIGINT size2, BIGINT size3,
19341972
BIGINT M0) {

0 commit comments

Comments
 (0)