Skip to content

Commit b997bbc

Browse files
committed
Documentation & changes to kereval0 to make it faster
1 parent 4995826 commit b997bbc

File tree

3 files changed

+40
-24
lines changed

3 files changed

+40
-24
lines changed

src/spreadinterp.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,11 @@ static FINUFFT_ALWAYS_INLINE auto ker_eval(FLT *FINUFFT_RESTRICT ker,
5959
const finufft_spread_opts &opts,
6060
const V... elems) noexcept;
6161
static FINUFFT_ALWAYS_INLINE FLT fold_rescale(FLT x, BIGINT N) noexcept;
62-
static FINUFFT_ALWAYS_INLINE void set_kernel_args(
63-
FLT *args, FLT x, const finufft_spread_opts &opts) noexcept;
62+
template<uint8_t ns>
63+
static FINUFFT_ALWAYS_INLINE void set_kernel_args(FLT *args, FLT x) noexcept;
64+
template<uint8_t N>
6465
static FINUFFT_ALWAYS_INLINE void evaluate_kernel_vector(
65-
FLT *ker, FLT *args, const finufft_spread_opts &opts, int N) noexcept;
66+
FLT *ker, FLT *args, const finufft_spread_opts &opts) noexcept;
6667
template<uint8_t w, class simd_type = xsimd::make_sized_batch_t<
6768
FLT, find_optimal_simd_width<FLT, w>()>> // aka ns
6869
static FINUFFT_ALWAYS_INLINE void eval_kernel_vec_Horner(
@@ -703,16 +704,15 @@ FLT evaluate_kernel(FLT x, const finufft_spread_opts &opts)
703704
return exp((FLT)opts.ES_beta * sqrt((FLT)1.0 - (FLT)opts.ES_c * x * x));
704705
}
705706

706-
void set_kernel_args(FLT *args, FLT x, const finufft_spread_opts &opts) noexcept
707+
template<uint8_t ns>
708+
void set_kernel_args(FLT *args, FLT x) noexcept
707709
// Fills vector args[] with kernel arguments x, x+1, ..., x+ns-1.
708710
// needed for the vectorized kernel eval of Ludvig af K.
709711
{
710-
int ns = opts.nspread;
711712
for (int i = 0; i < ns; i++) args[i] = x + (FLT)i;
712713
}
713-
714-
void evaluate_kernel_vector(FLT *ker, FLT *args, const finufft_spread_opts &opts,
715-
const int N) noexcept
714+
template<uint8_t N>
715+
void evaluate_kernel_vector(FLT *ker, FLT *args, const finufft_spread_opts &opts) noexcept
716716
/* Evaluate ES kernel for a vector of N arguments; by Ludvig af K.
717717
If opts.kerpad true, args and ker must be allocated for Npad, and args is
718718
written to (to pad to length Npad), only first N outputs are correct.
@@ -742,8 +742,7 @@ void evaluate_kernel_vector(FLT *ker, FLT *args, const finufft_spread_opts &opts
742742
if (opts.kerpad) {
743743
// padded part should be zero, in spread_subproblem_nd_kernels, there are
744744
// out of bound writes to trg arrays
745-
for (int i = N; i < Npad; ++i)
746-
ker[i] = 0.0;
745+
for (int i = N; i < Npad; ++i) ker[i] = 0.0;
747746
}
748747
} else {
749748
for (int i = 0; i < N; i++) // dummy for timing only
@@ -1798,8 +1797,8 @@ auto ker_eval(FLT *FINUFFT_RESTRICT ker, const finufft_spread_opts &opts,
17981797
}
17991798
if constexpr (kerevalmeth == 0) {
18001799
alignas(simd_type::arch_type::alignment()) std::array<T, MAX_NSPREAD> kernel_args{};
1801-
set_kernel_args(kernel_args.data(), inputs[i], opts);
1802-
evaluate_kernel_vector(ker + (i * MAX_NSPREAD), kernel_args.data(), opts, ns);
1800+
set_kernel_args<ns>(kernel_args.data(), inputs[i]);
1801+
evaluate_kernel_vector<ns>(ker + (i * MAX_NSPREAD), kernel_args.data(), opts);
18031802
}
18041803
}
18051804
return ker;

test/README

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ finufft{1,2,3}dmany_test{f}: accuracy/speed tests for vectorized transforms,
2222
in a given dimension. Types 1, 2, and 3 are tested.
2323
(exit code 0 is a pass).
2424
Call with no arguments for argument documentation.
25+
finufft3dkernel_test{f} : test of the kernel evaluation methods.
26+
Only 3D NUFFTs are tested.
27+
Types 1, 2, and 3 are tested.
28+
It requires Nmodes1 Nmodes2 Nmodes3 Nsrc:
29+
the sizes of the 3D grid and the number of sources.
30+
Optional arguments are the tolerance, debug flags,
31+
sort flag and upsampfac.
32+
(exit code 0 is a pass).
2533
dumbinputs{f} : test of edge cases, invalid inputs, and plan interface.
2634
No arguments needed (exit code 0 is a pass).
2735
testutils{f} : test of utils module.

test/finufft3dkernel_test.cpp

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,25 @@
44
using namespace std;
55
using namespace finufft::utils;
66

7-
const char *help[] = {"Tester for FINUFFT in 3d, all 3 types, either precision.",
8-
"",
9-
"Usage: finufft3d_test Nmodes1 Nmodes2 Nmodes3 Nsrc [tol [debug "
10-
"[spread_sort [upsampfac]]]]",
11-
"\teg:\tfinufft3d_test 100 200 50 1e6 1e-12 0 2 0.0 1e-11",
12-
"\tnotes:\tif errfail present, exit code 1 if any error > errfail",
13-
NULL};
14-
// Barnett 2/2/17 onwards.
15-
7+
const char *help[] = {
8+
"Tester for FINUFFT in 3d, all 3 types, either precision.",
9+
"",
10+
"Usage: finufft3d_test Nmodes1 Nmodes2 Nmodes3 Nsrc",
11+
"\t[tol] error tolerance (default 1e-6)",
12+
"\t[debug] (default 0) 0: silent, 1: text, 2: as 1 but also spreader",
13+
"\t[spread_sort] (default 2) 0: don't sort NU pts, 1: do, 2: auto",
14+
"\t[upsampfac] (default 2.0)",
15+
"\teg: finufft3d_test 100 200 50 1e6 1e-12 0 2 0.0",
16+
"\tnotes: exit code 1 if any error > tol",
17+
nullptr};
18+
/**
19+
* @brief Test the 3D NUFFT of type 1, 2, and 3.
20+
* It evaluates the error of the kernel evaluation methods.
21+
* It uses err(a,b)=||a-b||_2 / ||a||_2 as the error metric.
22+
* It returns the FINUFFT error code if it is not 0.
23+
* It returns 1 if any error exceeds tol.
24+
* It returns 0 if test passes.
25+
*/
1626
int main(int argc, char *argv[]) {
1727
BIGINT M, N1, N2, N3; // M = # srcs, N1,N2,N3 = # modes
1828
double w, tol = 1e-6; // default
@@ -142,9 +152,6 @@ int main(int argc, char *argv[]) {
142152
std::vector<FLT> s(N); // targ freqs (1-cmpt)
143153
std::vector<FLT> t(N); // targ freqs (2-cmpt)
144154
std::vector<FLT> u(N); // targ freqs (3-cmpt)
145-
FLT S1 = (FLT)N1 / 2; // choose freq range sim to type 1
146-
FLT S2 = (FLT)N2 / 2;
147-
FLT S3 = (FLT)N3 / 2;
148155

149156
timer.restart();
150157
printf("kerevalmeth 0:\n");
@@ -171,5 +178,7 @@ int main(int argc, char *argv[]) {
171178
err = relerrtwonorm(N, F0.data(), F1.data());
172179
errmax = max(err, errmax);
173180
printf("\ttype 3 rel l2-err in F is %.3g\n", err);
181+
// return 1 if any error exceeds tol
182+
// or return finufft error code if it is not 0
174183
return (errmax > tol);
175184
}

0 commit comments

Comments
 (0)