@@ -800,8 +800,9 @@ Two upsampfacs implemented. Params must match ref formula. Barnett 4/24/18 */
800800}
801801
802802template <uint8_t ns>
803- static void interp_line_wrap (FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
804- const BIGINT i1, const UBIGINT N1) {
803+ FINUFFT_NEVER_INLINE static void interp_line_wrap (FLT *FINUFFT_RESTRICT target,
804+ const FLT *du, const FLT *ker,
805+ const BIGINT i1, const UBIGINT N1) {
805806 /* This function is called when the kernel wraps around the grid. It is
806807 slower than interp_line.
807808 M. Barbone July 2024: - moved the logic to a separate function
@@ -954,9 +955,9 @@ void interp_line(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker,
954955}
955956
956957template <uint8_t ns, class simd_type >
957- static void interp_square_wrap (FLT *FINUFFT_RESTRICT target, const FLT *du,
958- const FLT *ker1 , const FLT *ker2 , const BIGINT i1 ,
959- const BIGINT i2, const UBIGINT N1, const UBIGINT N2) {
958+ FINUFFT_NEVER_INLINE static void interp_square_wrap (
959+ FLT *FINUFFT_RESTRICT target, const FLT *du , const FLT *ker1 , const FLT *ker2 ,
960+ const BIGINT i1, const BIGINT i2, const UBIGINT N1, const UBIGINT N2) {
960961 /*
961962 * This function is called when the kernel wraps around the grid. It is slower than
962963 * the non wrapping version.
@@ -1137,10 +1138,10 @@ void interp_square(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
11371138}
11381139
11391140template <uint8_t ns, class simd_type >
1140- static void interp_cube_wrapped (FLT *FINUFFT_RESTRICT target, const FLT *du,
1141- const FLT *ker1 , const FLT *ker2 , const FLT *ker3 ,
1142- const BIGINT i1, const BIGINT i2, const BIGINT i3,
1143- const UBIGINT N1, const UBIGINT N2, const UBIGINT N3) {
1141+ FINUFFT_NEVER_INLINE static void interp_cube_wrapped (
1142+ FLT *FINUFFT_RESTRICT target, const FLT *du , const FLT *ker1 , const FLT *ker2 ,
1143+ const FLT *ker3, const BIGINT i1, const BIGINT i2, const BIGINT i3, const UBIGINT N1 ,
1144+ const UBIGINT N2, const UBIGINT N3) {
11441145 /*
11451146 * This function is called when the kernel wraps around the cube.
11461147 * Similarly to 2D and 1D wrapping, this is slower than the non wrapping version.
@@ -1246,6 +1247,7 @@ void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
12461247 static constexpr auto padding = get_padding<FLT, 2 * ns>();
12471248 static constexpr auto alignment = arch_t::alignment ();
12481249 static constexpr auto simd_size = simd_type::size;
1250+ static constexpr auto ker23_size = (ns + simd_size - 1 ) & -simd_size;
12491251 static constexpr uint8_t line_vectors = (2 * ns + padding) / simd_size;
12501252 const auto in_bounds_1 = (i1 >= 0 ) & (i1 + ns <= N1);
12511253 const auto in_bounds_2 = (i2 >= 0 ) & (i2 + ns <= N2);
@@ -1254,13 +1256,22 @@ void interp_cube(FLT *FINUFFT_RESTRICT target, const FLT *du, const FLT *ker1,
12541256 if (in_bounds_1 && in_bounds_2 && in_bounds_3 && (i1 + ns + (padding + 1 ) / 2 < N1)) {
12551257 const auto line = [N1, N2, i1, i2, i3, ker2, ker3, du]() constexpr noexcept {
12561258 std::array<simd_type, line_vectors> line{0 }, du_pts{};
1259+ alignas (alignment) std::array<FLT, ker23_size> ker23_array{};
12571260 const UBIGINT base_oz = N1 * N2 * i3; // Move invariant part outside the loop
12581261 for (uint8_t dz{0 }; dz < ns; ++dz) {
12591262 const auto oz = base_oz + N1 * N2 * dz; // Only the dz part is inside the loop
1260- const auto base_du_ptr = du + 2 * (oz + N1 * i2 + UBIGINT (i1));
1263+ const auto base_du_ptr = du + 2 * UBIGINT (oz + N1 * i2 + i1);
1264+ {
1265+ const simd_type ker3_v{ker3[dz]};
1266+ for (uint8_t dy{0 }; dy < ns; dy += simd_size) {
1267+ const auto ker2_v = simd_type::load_aligned (ker2 + dy);
1268+ const auto ker23_v = ker2_v * ker3_v;
1269+ ker23_v.store_aligned (ker23_array.data () + dy);
1270+ }
1271+ }
12611272 for (uint8_t dy{0 }; dy < ns; ++dy) {
12621273 const auto du_ptr = base_du_ptr + 2 * N1 * dy; // (see above)
1263- const simd_type ker23_v{ker2 [dy] * ker3[dz ]};
1274+ const simd_type ker23_v{ker23_array [dy]};
12641275 // First loop: Load all du_pt into the du_pts array
12651276 for (uint8_t l{0 }; l < line_vectors; ++l) {
12661277 du_pts[l] = simd_type::load_unaligned (l * simd_size + du_ptr);
0 commit comments