Skip to content

Commit 223edc6

Browse files
committed
minor vectorization in 3D
1 parent 52150f0 commit 223edc6

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

src/spreadinterp.cpp

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -800,8 +800,9 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
800800
}
801801

802802
template<uint8_t ns>
803-
static void interp_line_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
804-
const BIGINT i1, const UBIGINT N1) {
803+
FINUFFT_NEVER_INLINE static void interp_line_wrap(FLT *FINUFFT_RESTRICT target,
804+
const FLT *du, const FLT *ker,
805+
const BIGINT i1, const UBIGINT N1) {
805806
/* This function is called when the kernel wraps around the grid. It is
806807
slower than interp_line.
807808
M. Barbone July 2024: - moved the logic to a separate function
@@ -954,9 +955,9 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
954955
}
955956

956957
template<uint8_t ns, class simd_type>
957-
static void interp_square_wrap(FLT *FINUFFT_RESTRICT target, const FLT *du,
958-
const FLT *ker1, const FLT *ker2, const BIGINT i1,
959-
const BIGINT i2, const UBIGINT N1, const UBIGINT N2) {
958+
FINUFFT_NEVER_INLINE static void interp_square_wrap(
959+
FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1, const FLT *ker2,
960+
const BIGINT i1, const BIGINT i2, const UBIGINT N1, const UBIGINT N2) {
960961
/*
961962
* This function is called when the kernel wraps around the grid. It is slower than
962963
* the non-wrapping version.
@@ -1137,10 +1138,10 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
11371138
}
11381139

11391140
template<uint8_t ns, class simd_type>
1140-
static void interp_cube_wrapped(FLT *FINUFFT_RESTRICT target, const FLT *du,
1141-
const FLT *ker1, const FLT *ker2, const FLT *ker3,
1142-
const BIGINT i1, const BIGINT i2, const BIGINT i3,
1143-
const UBIGINT N1, const UBIGINT N2, const UBIGINT N3) {
1141+
FINUFFT_NEVER_INLINE static void interp_cube_wrapped(
1142+
FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1, const FLT *ker2,
1143+
const FLT *ker3, const BIGINT i1, const BIGINT i2, const BIGINT i3, const UBIGINT N1,
1144+
const UBIGINT N2, const UBIGINT N3) {
11441145
/*
11451146
* This function is called when the kernel wraps around the cube.
11461147
* As with 1D and 2D wrapping, this is slower than the non-wrapping version.
@@ -1246,6 +1247,7 @@ void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
12461247
static constexpr auto padding = get_padding<FLT, 2 * ns>();
12471248
static constexpr auto alignment = arch_t::alignment();
12481249
static constexpr auto simd_size = simd_type::size;
1250+
static constexpr auto ker23_size = (ns + simd_size - 1) & -simd_size;
12491251
static constexpr uint8_t line_vectors = (2 * ns + padding) / simd_size;
12501252
const auto in_bounds_1 = (i1 >= 0) & (i1 + ns <= N1);
12511253
const auto in_bounds_2 = (i2 >= 0) & (i2 + ns <= N2);
@@ -1254,13 +1256,22 @@ void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
12541256
if (in_bounds_1 && in_bounds_2 && in_bounds_3 && (i1 + ns + (padding + 1) / 2 < N1)) {
12551257
const auto line = [N1, N2, i1, i2, i3, ker2, ker3, du]() constexpr noexcept {
12561258
std::array<simd_type, line_vectors> line{0}, du_pts{};
1259+
alignas(alignment) std::array<FLT, ker23_size> ker23_array{};
12571260
const UBIGINT base_oz = N1 * N2 * i3; // Move invariant part outside the loop
12581261
for (uint8_t dz{0}; dz < ns; ++dz) {
12591262
const auto oz = base_oz + N1 * N2 * dz; // Only the dz part is inside the loop
1260-
const auto base_du_ptr = du + 2 * (oz + N1 * i2 + UBIGINT(i1));
1263+
const auto base_du_ptr = du + 2 * UBIGINT(oz + N1 * i2 + i1);
1264+
{
1265+
const simd_type ker3_v{ker3[dz]};
1266+
for (uint8_t dy{0}; dy < ns; dy += simd_size) {
1267+
const auto ker2_v = simd_type::load_aligned(ker2 + dy);
1268+
const auto ker23_v = ker2_v * ker3_v;
1269+
ker23_v.store_aligned(ker23_array.data() + dy);
1270+
}
1271+
}
12611272
for (uint8_t dy{0}; dy < ns; ++dy) {
12621273
const auto du_ptr = base_du_ptr + 2 * N1 * dy; // (see above)
1263-
const simd_type ker23_v{ker2[dy] * ker3[dz]};
1274+
const simd_type ker23_v{ker23_array[dy]};
12641275
// First loop: Load all du_pt into the du_pts array
12651276
for (uint8_t l{0}; l < line_vectors; ++l) {
12661277
du_pts[l] = simd_type::load_unaligned(l * simd_size + du_ptr);

0 commit comments

Comments
 (0)