#include <finufft/test_defs.h>
// this enforces recompilation, responding to SINGLE...
#include "directft/dirft3d.cpp"

#include <cstdio>
#include <iomanip>
#include <iostream>
#include <vector>

using namespace std;
using namespace finufft::utils;

// Usage text printed to stderr when the command line is malformed.
// The array is nullptr-terminated so callers can iterate without a length.
const char *help[] = {
    "Test spread_kerevalmeth=0 & 1 match, for 3 types of 3D transf, either prec.",
    "Usage: finufft3dkernel_test Nmodes1 Nmodes2 Nmodes3 Nsrc "
    "[tol [debug [spread_sort [upsampfac]]]]",
    "\t[tol] error tolerance (default 1e-6)",
    "\t[debug] (default 0) 0: silent, 1: text, 2: as 1 but also spreader",
    "\t[spread_sort] (default 2) 0: don't sort NU pts, 1: do, 2: auto",
    "\t[upsampfac] (default 2.0)",
    "\teg: finufft3dkernel_test 100 200 50 1e6 1e-12 0 2 0.0",
    "\tnotes: exit code 1 if any error > tol, 2 on bad usage,",
    "\t       or the FINUFFT error code (>1) if a transform fails",
    nullptr};
| 17 | +/** |
| 18 | + * @brief Test the 3D NUFFT of type 1, 2, and 3. |
| 19 | + * It evaluates the error of the kernel evaluation methods. |
| 20 | + * It uses err(a,b)=||a-b||_2 / ||a||_2 as the error metric. |
| 21 | + * It return FINUFFT error code if it is not 0. |
| 22 | + * It returns 1 if any error exceeds tol. |
| 23 | + * It returns 0 if test passes. |
| 24 | + */ |
| 25 | +int main(int argc, char *argv[]) { |
| 26 | + BIGINT M, N1, N2, N3; // M = # srcs, N1,N2,N3 = # modes |
| 27 | + double w, tol = 1e-6; // default |
| 28 | + double err, errmax = 0; |
| 29 | + finufft_opts opts0, opts1; |
| 30 | + FINUFFT_DEFAULT_OPTS(&opts0); |
| 31 | + FINUFFT_DEFAULT_OPTS(&opts1); |
| 32 | + opts0.spread_kerevalmeth = 0; |
| 33 | + opts1.spread_kerevalmeth = 1; |
| 34 | + // opts.fftw = FFTW_MEASURE; // change from usual FFTW_ESTIMATE |
| 35 | + // opts.spread_max_sp_size = 3e4; // override test |
| 36 | + // opts.spread_nthr_atomic = 15; // " |
| 37 | + int isign = +1; // choose which exponential sign to test |
| 38 | + if (argc < 5 || argc > 10) { |
| 39 | + for (int i = 0; help[i]; ++i) fprintf(stderr, "%s\n", help[i]); |
| 40 | + return 2; |
| 41 | + } |
| 42 | + sscanf(argv[1], "%lf", &w); |
| 43 | + N1 = (BIGINT)w; |
| 44 | + sscanf(argv[2], "%lf", &w); |
| 45 | + N2 = (BIGINT)w; |
| 46 | + sscanf(argv[3], "%lf", &w); |
| 47 | + N3 = (BIGINT)w; |
| 48 | + sscanf(argv[4], "%lf", &w); |
| 49 | + M = (BIGINT)w; |
| 50 | + if (argc > 5) sscanf(argv[5], "%lf", &tol); |
| 51 | + if (argc > 6) sscanf(argv[6], "%d", &opts0.debug); // can be 0,1 or 2 |
| 52 | + opts0.spread_debug = (opts0.debug > 1) ? 1 : 0; // see output from spreader |
| 53 | + if (argc > 7) sscanf(argv[7], "%d", &opts0.spread_sort); |
| 54 | + if (argc > 8) { |
| 55 | + sscanf(argv[8], "%lf", &w); |
| 56 | + opts0.upsampfac = (FLT)w; |
| 57 | + } |
| 58 | + |
| 59 | + opts1 = opts0; |
| 60 | + opts0.spread_kerevalmeth = 0; |
| 61 | + opts1.spread_kerevalmeth = 1; |
| 62 | + |
| 63 | + cout << scientific << setprecision(15); |
| 64 | + const BIGINT N = N1 * N2 * N3; |
| 65 | + |
| 66 | + std::vector<FLT> x(M); // NU pts x coords |
| 67 | + std::vector<FLT> y(M); // NU pts y coords |
| 68 | + std::vector<FLT> z(M); // NU pts z coords |
| 69 | + std::vector<CPX> c0(M), c1(N); // strengths |
| 70 | + std::vector<CPX> F0(N); // mode ampls kereval 0 |
| 71 | + std::vector<CPX> F1(N); // mode ampls kereval 1 |
| 72 | +#pragma omp parallel |
| 73 | + { |
| 74 | + unsigned int se = MY_OMP_GET_THREAD_NUM(); // needed for parallel random #s |
| 75 | +#pragma omp for schedule(static, TEST_RANDCHUNK) |
| 76 | + for (BIGINT j = 0; j < M; ++j) { |
| 77 | + x[j] = M_PI * randm11r(&se); |
| 78 | + y[j] = M_PI * randm11r(&se); |
| 79 | + z[j] = M_PI * randm11r(&se); |
| 80 | + c0[j] = crandm11r(&se); |
| 81 | + } |
| 82 | + } |
| 83 | + c1 = c0; // copy strengths |
| 84 | + printf("test 3d type 1:\n"); // -------------- type 1 |
| 85 | + printf("kerevalmeth 0:\n"); |
| 86 | + CNTime timer{}; |
| 87 | + timer.start(); |
| 88 | + int ier = FINUFFT3D1(M, x.data(), y.data(), z.data(), c0.data(), isign, tol, N1, N2, N3, |
| 89 | + F0.data(), &opts0); |
| 90 | + double ti = timer.elapsedsec(); |
| 91 | + if (ier > 1) { |
| 92 | + printf("error (ier=%d)!\n", ier); |
| 93 | + return ier; |
| 94 | + } else |
| 95 | + printf(" %lld NU pts to (%lld,%lld,%lld) modes in %.3g s \t%.3g NU pts/s\n", |
| 96 | + (long long)M, (long long)N1, (long long)N2, (long long)N3, ti, M / ti); |
| 97 | + printf("kerevalmeth 1:\n"); |
| 98 | + timer.restart(); |
| 99 | + ier = FINUFFT3D1(M, x.data(), y.data(), z.data(), c0.data(), isign, tol, N1, N2, N3, |
| 100 | + F1.data(), &opts1); |
| 101 | + ti = timer.elapsedsec(); |
| 102 | + if (ier > 1) { |
| 103 | + printf("error (ier=%d)!\n", ier); |
| 104 | + return ier; |
| 105 | + } else |
| 106 | + printf(" %lld NU pts to (%lld,%lld,%lld) modes in %.3g s \t%.3g NU pts/s\n", |
| 107 | + (long long)M, (long long)N1, (long long)N2, (long long)N3, ti, M / ti); |
| 108 | + |
| 109 | + err = relerrtwonorm(N, F0.data(), F1.data()); |
| 110 | + errmax = max(err, errmax); |
| 111 | + printf("\ttype 1 rel l2-err in F is %.3g\n", err); |
| 112 | + // copy F0 to F1 so that we can test type 2 |
| 113 | + F1 = F0; |
| 114 | + printf("kerevalmeth 0:\n"); |
| 115 | + timer.restart(); |
| 116 | + ier = FINUFFT3D2(M, x.data(), y.data(), z.data(), c0.data(), isign, tol, N1, N2, N3, |
| 117 | + F0.data(), &opts0); |
| 118 | + ti = timer.elapsedsec(); |
| 119 | + if (ier > 1) { |
| 120 | + printf("error (ier=%d)!\n", ier); |
| 121 | + return ier; |
| 122 | + } else |
| 123 | + printf(" (%lld,%lld,%lld) modes to %lld NU pts in %.3g s \t%.3g NU pts/s\n", |
| 124 | + (long long)N1, (long long)N2, (long long)N3, (long long)M, ti, M / ti); |
| 125 | + printf("kerevalmeth 1:\n"); |
| 126 | + timer.restart(); |
| 127 | + ier = FINUFFT3D2(M, x.data(), y.data(), z.data(), c1.data(), isign, tol, N1, N2, N3, |
| 128 | + F0.data(), &opts1); |
| 129 | + ti = timer.elapsedsec(); |
| 130 | + if (ier > 1) { |
| 131 | + printf("error (ier=%d)!\n", ier); |
| 132 | + return ier; |
| 133 | + } else |
| 134 | + printf(" (%lld,%lld,%lld) modes to %lld NU pts in %.3g s \t%.3g NU pts/s\n", |
| 135 | + (long long)N1, (long long)N2, (long long)N3, (long long)M, ti, M / ti); |
| 136 | + err = relerrtwonorm(M, c0.data(), c1.data()); |
| 137 | + errmax = std::max(err, errmax); |
| 138 | + printf("\ttype 2 rel l2-err in c is %.3g\n", err); |
| 139 | + |
| 140 | + printf("test 3d type 3:\n"); // -------------- type 3 |
| 141 | +#pragma omp parallel |
| 142 | + { |
| 143 | + unsigned int se = MY_OMP_GET_THREAD_NUM(); |
| 144 | +#pragma omp for schedule(static, TEST_RANDCHUNK) |
| 145 | + for (BIGINT j = 0; j < M; ++j) { |
| 146 | + x[j] = 2.0 + M_PI * randm11r(&se); // new x_j srcs, offset from origin |
| 147 | + y[j] = -3.0 + M_PI * randm11r(&se); // " y_j |
| 148 | + z[j] = 1.0 + M_PI * randm11r(&se); // " z_j |
| 149 | + } |
| 150 | + } |
| 151 | + std::vector<FLT> s(N); // targ freqs (1-cmpt) |
| 152 | + std::vector<FLT> t(N); // targ freqs (2-cmpt) |
| 153 | + std::vector<FLT> u(N); // targ freqs (3-cmpt) |
| 154 | + |
| 155 | + timer.restart(); |
| 156 | + printf("kerevalmeth 0:\n"); |
| 157 | + ier = FINUFFT3D3(M, x.data(), y.data(), z.data(), c0.data(), isign, tol, N, s.data(), |
| 158 | + t.data(), u.data(), F0.data(), &opts0); |
| 159 | + ti = timer.elapsedsec(); |
| 160 | + if (ier > 1) { |
| 161 | + printf("error (ier=%d)!\n", ier); |
| 162 | + return ier; |
| 163 | + } else |
| 164 | + printf("\t%lld NU to %lld NU in %.3g s \t%.3g tot NU pts/s\n", (long long)M, |
| 165 | + (long long)N, ti, (M + N) / ti); |
| 166 | + timer.restart(); |
| 167 | + printf("kerevalmeth 1:\n"); |
| 168 | + ier = FINUFFT3D3(M, x.data(), y.data(), z.data(), c0.data(), isign, tol, N, s.data(), |
| 169 | + t.data(), u.data(), F1.data(), &opts1); |
| 170 | + ti = timer.elapsedsec(); |
| 171 | + if (ier > 1) { |
| 172 | + printf("error (ier=%d)!\n", ier); |
| 173 | + return ier; |
| 174 | + } else |
| 175 | + printf("\t%lld NU to %lld NU in %.3g s \t%.3g tot NU pts/s\n", (long long)M, |
| 176 | + (long long)N, ti, (M + N) / ti); |
| 177 | + err = relerrtwonorm(N, F0.data(), F1.data()); |
| 178 | + errmax = max(err, errmax); |
| 179 | + printf("\ttype 3 rel l2-err in F is %.3g\n", err); |
| 180 | + // return 1 if any error exceeds tol |
| 181 | + // or return finufft error code if it is not 0 |
| 182 | + return (errmax > tol); |
| 183 | +} |
0 commit comments