@@ -222,49 +222,56 @@ static void onedim_fseries_kernel(BIGINT nf, std::vector<T> &fwkerhalf,
222222 }
223223}
224224
225- template <typename T>
226- static void onedim_nuft_kernel (BIGINT nk, const std::vector<T> &k, std::vector<T> &phihat,
227- const finufft_spread_opts &opts)
228- /*
229- Approximates exact 1D Fourier transform of cnufftspread's real symmetric
230- kernel, directly via q-node quadrature on Euler-Fourier formula, exploiting
231- narrowness of kernel. Evaluates at set of arbitrary freqs k in [-pi, pi),
232- for a kernel with x measured in grid-spacings. (See previous routine for
233- FT definition).
225+ template <typename T> class KernelFseries {
226+ private:
227+ std::vector<T> z, f;
228+
229+ public:
230+ /*
231+ Approximates exact 1D Fourier transform of cnufftspread's real symmetric
232+ kernel, directly via q-node quadrature on Euler-Fourier formula, exploiting
233+ narrowness of kernel. The constructor only precomputes the q quadrature nodes z
234+ and kernel-weighted weights f; operator()(k) then returns the transform at one freq
235+ k in [-pi, pi), x measured in grid-spacings. (See previous routine for FT definition).
236+
237+ Inputs:
238+ opts - spreading opts object, needed to eval kernel (must be already set up)
239+
240+ Barnett 2/8/17. openmp since cos slow 2/9/17 (not used inside this class itself).
241+ To do (Nov 2024): replace evaluate_kernel by evaluate_kernel_horner.
242+ */
243+ KernelFseries (const finufft_spread_opts &opts) {
244+ T J2 = opts.nspread / 2.0 ; // J/2, half-width of ker z-support
245+ // # quadr nodes in z (from 0 to J/2; reflections will be added)...
246+ int q = (int )(2 + 2.0 * J2); // > pi/2 ratio. cannot exceed MAX_NQUAD
247+ if (opts.debug ) printf (" q (# ker FT quadr pts) = %d\n " , q);
248+ std::vector<double > Z (2 * q), W (2 * q);
249+ legendre_compute_glr (2 * q, Z.data (), W.data ()); // only half the nodes used, eg on
250+ // (0,1)
251+ z.resize (q);
252+ f.resize (q);
253+ for (int n = 0 ; n < q; ++n) {
254+ z[n] = T (Z[n] * J2); // quadr nodes for [0,J/2]
255+ f[n] = J2 * T (W[n]) * evaluate_kernel (z[n], opts); // w/ quadr weights
256+ }
257+ }
234258
235- Inputs:
236- nk - number of freqs
237- k - frequencies, dual to the kernel's natural argument, ie exp(i.k.z)
238- Note, z is in grid-point units, and k values must be in [-pi, pi) for
239- accuracy.
240- opts - spreading opts object, needed to eval kernel (must be already set up)
259+ /*
260+ Evaluates the Fourier transform of the kernel at a single point.
241261
242- Outputs :
243- phihat - real Fourier transform evaluated at freqs (alloc for nk Ts )
262+ Inputs :
263+ k - frequency, dual to the kernel's natural argument, ie exp(i.k.z )
244264
245- Barnett 2/8/17. openmp since cos slow 2/9/17.
246- To do (Nov 2024): replace evaluate_kernel by evaluate_kernel_horner.
247- */
248- {
249- T J2 = opts.nspread / 2.0 ; // J/2, half-width of ker z-support
250- // # quadr nodes in z (from 0 to J/2; reflections will be added)...
251- int q = (int )(2 + 2.0 * J2); // > pi/2 ratio. cannot exceed MAX_NQUAD
252- if (opts.debug ) printf (" q (# ker FT quadr pts) = %d\n " , q);
253- T f[MAX_NQUAD];
254- double z[2 * MAX_NQUAD], w[2 * MAX_NQUAD]; // glr needs double
255- legendre_compute_glr (2 * q, z, w); // only half the nodes used, eg on (0,1)
256- for (int n = 0 ; n < q; ++n) {
257- z[n] *= (T)J2; // quadr nodes for [0,J/2]
258- f[n] = J2 * (T)w[n] * evaluate_kernel ((T)z[n], opts); // w/ quadr weights
259- }
260- #pragma omp parallel for num_threads(opts.nthreads)
261- for (BIGINT j = 0 ; j < nk; ++j) { // loop along output array
262- T x = 0.0 ; // register
263- for (int n = 0 ; n < q; ++n)
264- x += f[n] * 2 * cos (k[j] * (T)z[n]); // pos & neg freq pair. use T cos!
265- phihat[j] = x;
265+ Outputs:
266+ phihat - real Fourier transform evaluated at freq k
267+ */
268+ FINUFFT_ALWAYS_INLINE T operator ()(T k) {
269+ T x = 0 ;
270+ for (size_t n = 0 ; n < z.size (); ++n)
271+ x += f[n] * 2 * cos (k * z[n]); // pos & neg freq pair. use T cos!
272+ return x;
266273 }
267- }
274+ };
268275
269276template <typename T>
270277static void deconvolveshuffle1d (int dir, T prefac, const std::vector<T> &ker, BIGINT ms,
@@ -860,37 +867,29 @@ int FINUFFT_PLAN_T<TF>::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF
860867 for (BIGINT j = 0 ; j < nj; ++j)
861868 prephase[j] = {1.0 , 0.0 }; // *** or keep flag so no mult in exec??
862869
863- // rescale the target s_k etc to s'_k etc...
864- #pragma omp parallel for num_threads(opts.nthreads) schedule(static)
865- for (BIGINT k = 0 ; k < nk; ++k) {
866- for (int idim = 0 ; idim < dim; ++idim)
867- STUp[idim][k] =
868- t3P.h [idim] * t3P.gam [idim] * (STU_in[idim][k] - t3P.D [idim]); // so |s'_k| <
869- // pi/R
870- }
870+ KernelFseries<TF> fseries (spopts);
871871 // (old STEP 3a) Compute deconvolution post-factors array (per targ pt)...
872872 // (exploits that FT separates because kernel is prod of 1D funcs)
873873 deconv.resize (nk);
874- std::array<std::vector<TF>, 3 > phiHatk;
875- for (int idim = 0 ; idim < dim; ++idim) {
876- phiHatk[idim].resize (nk);
877- onedim_nuft_kernel (nk, STUp[idim], phiHatk[idim], spopts); // fill phiHat1
878- }
879874 // C can be nan or inf if M=0, no input NU pts
880- int Cfinite =
875+ bool Cfinite =
881876 std::isfinite (t3P.C [0 ]) && std::isfinite (t3P.C [1 ]) && std::isfinite (t3P.C [2 ]);
882- int Cnonzero = t3P.C [0 ] != 0.0 || t3P.C [1 ] != 0.0 || t3P.C [2 ] != 0.0 ; // cen
877+ bool Cnonzero = t3P.C [0 ] != 0.0 || t3P.C [1 ] != 0.0 || t3P.C [2 ] != 0.0 ; // cen
878+ bool do_phase = Cfinite && Cnonzero;
883879#pragma omp parallel for num_threads(opts.nthreads) schedule(static)
884880 for (BIGINT k = 0 ; k < nk; ++k) { // .... loop over NU targ freqs
885881 TF phiHat = 1 ;
886- for (int idim = 0 ; idim < dim; ++idim) phiHat *= phiHatk[idim][k];
887- deconv[k] = (std::complex <TF>)(1.0 / phiHat);
888- if (Cfinite && Cnonzero) {
889- TF phase = 0 ;
890- for (int idim = 0 ; idim < dim; ++idim)
891- phase += (STU_in[idim][k] - t3P.D [idim]) * t3P.C [idim];
892- deconv[k] *= std::polar (TF (1 ), isign * phase); // Euler e^{+-i.phase}
882+ TF phase = 0 ;
883+ for (int idim = 0 ; idim < dim; ++idim) {
884+ auto tSTUin = STU_in[idim][k];
885+ // rescale the target s_k etc to s'_k etc...
886+ auto tSTUp = t3P.h [idim] * t3P.gam [idim] * (tSTUin - t3P.D [idim]); // so |s'_k| <
887+ // pi/R
888+ phiHat *= fseries (tSTUp);
889+ if (do_phase) phase += (tSTUin - t3P.D [idim]) * t3P.C [idim];
890+ STUp[idim][k] = tSTUp;
893891 }
892+ deconv[k] = do_phase ? std::polar (TF (1 ) / phiHat, isign * phase) : TF (1 ) / phiHat;
894893 }
895894 if (opts.debug )
896895 printf (" [%s t3] phase & deconv factors:\t %.3g s\n " , __func__, timer.elapsedsec ());
0 commit comments