Skip to content

Commit 8baf39d

Browse files
authored
Tweak type 3 setpts (#609)
* reduce memory overhead in Type 3 setpts by eliminating temporary phihatk arrays; reduce large array reads/writes * dummy commit to trigger formatter * fix typo * use FINUFFT_ALWAYS_INLNE
1 parent cc897a5 commit 8baf39d

File tree

1 file changed

+60
-61
lines changed

1 file changed

+60
-61
lines changed

src/finufft_core.cpp

Lines changed: 60 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -222,49 +222,56 @@ static void onedim_fseries_kernel(BIGINT nf, std::vector<T> &fwkerhalf,
222222
}
223223
}
224224

225-
template<typename T>
226-
static void onedim_nuft_kernel(BIGINT nk, const std::vector<T> &k, std::vector<T> &phihat,
227-
const finufft_spread_opts &opts)
228-
/*
229-
Approximates exact 1D Fourier transform of cnufftspread's real symmetric
230-
kernel, directly via q-node quadrature on Euler-Fourier formula, exploiting
231-
narrowness of kernel. Evaluates at set of arbitrary freqs k in [-pi, pi),
232-
for a kernel with x measured in grid-spacings. (See previous routine for
233-
FT definition).
225+
template<typename T> class KernelFseries {
226+
private:
227+
std::vector<T> z, f;
228+
229+
public:
230+
/*
231+
Approximates exact 1D Fourier transform of cnufftspread's real symmetric
232+
kernel, directly via q-node quadrature on Euler-Fourier formula, exploiting
233+
narrowness of kernel. Evaluates at set of arbitrary freqs k in [-pi, pi),
234+
for a kernel with x measured in grid-spacings. (See previous routine for
235+
FT definition).
236+
237+
Inputs:
238+
opts - spreading opts object, needed to eval kernel (must be already set up)
239+
240+
Barnett 2/8/17. openmp since cos slow 2/9/17.
241+
To do (Nov 2024): replace evaluate_kernel by evaluate_kernel_horner.
242+
*/
243+
KernelFseries(const finufft_spread_opts &opts) {
244+
T J2 = opts.nspread / 2.0; // J/2, half-width of ker z-support
245+
// # quadr nodes in z (from 0 to J/2; reflections will be added)...
246+
int q = (int)(2 + 2.0 * J2); // > pi/2 ratio. cannot exceed MAX_NQUAD
247+
if (opts.debug) printf("q (# ker FT quadr pts) = %d\n", q);
248+
std::vector<double> Z(2 * q), W(2 * q);
249+
legendre_compute_glr(2 * q, Z.data(), W.data()); // only half the nodes used, eg on
250+
// (0,1)
251+
z.resize(q);
252+
f.resize(q);
253+
for (int n = 0; n < q; ++n) {
254+
z[n] = T(Z[n] * J2); // quadr nodes for [0,J/2]
255+
f[n] = J2 * T(W[n]) * evaluate_kernel(z[n], opts); // w/ quadr weights
256+
}
257+
}
234258

235-
Inputs:
236-
nk - number of freqs
237-
k - frequencies, dual to the kernel's natural argument, ie exp(i.k.z)
238-
Note, z is in grid-point units, and k values must be in [-pi, pi) for
239-
accuracy.
240-
opts - spreading opts object, needed to eval kernel (must be already set up)
259+
/*
260+
Evaluates the Fourier transform of the kernel at a single point.
241261
242-
Outputs:
243-
phihat - real Fourier transform evaluated at freqs (alloc for nk Ts)
262+
Inputs:
263+
k - frequency, dual to the kernel's natural argument, ie exp(i.k.z)
244264
245-
Barnett 2/8/17. openmp since cos slow 2/9/17.
246-
To do (Nov 2024): replace evaluate_kernel by evaluate_kernel_horner.
247-
*/
248-
{
249-
T J2 = opts.nspread / 2.0; // J/2, half-width of ker z-support
250-
// # quadr nodes in z (from 0 to J/2; reflections will be added)...
251-
int q = (int)(2 + 2.0 * J2); // > pi/2 ratio. cannot exceed MAX_NQUAD
252-
if (opts.debug) printf("q (# ker FT quadr pts) = %d\n", q);
253-
T f[MAX_NQUAD];
254-
double z[2 * MAX_NQUAD], w[2 * MAX_NQUAD]; // glr needs double
255-
legendre_compute_glr(2 * q, z, w); // only half the nodes used, eg on (0,1)
256-
for (int n = 0; n < q; ++n) {
257-
z[n] *= (T)J2; // quadr nodes for [0,J/2]
258-
f[n] = J2 * (T)w[n] * evaluate_kernel((T)z[n], opts); // w/ quadr weights
259-
}
260-
#pragma omp parallel for num_threads(opts.nthreads)
261-
for (BIGINT j = 0; j < nk; ++j) { // loop along output array
262-
T x = 0.0; // register
263-
for (int n = 0; n < q; ++n)
264-
x += f[n] * 2 * cos(k[j] * (T)z[n]); // pos & neg freq pair. use T cos!
265-
phihat[j] = x;
265+
Outputs:
266+
phihat - real Fourier transform evaluated at freq k
267+
*/
268+
FINUFFT_ALWAYS_INLINE T operator()(T k) {
269+
T x = 0;
270+
for (size_t n = 0; n < z.size(); ++n)
271+
x += f[n] * 2 * cos(k * z[n]); // pos & neg freq pair. use T cos!
272+
return x;
266273
}
267-
}
274+
};
268275

269276
template<typename T>
270277
static void deconvolveshuffle1d(int dir, T prefac, const std::vector<T> &ker, BIGINT ms,
@@ -860,37 +867,29 @@ int FINUFFT_PLAN_T<TF>::setpts(BIGINT nj, TF *xj, TF *yj, TF *zj, BIGINT nk, TF
860867
for (BIGINT j = 0; j < nj; ++j)
861868
prephase[j] = {1.0, 0.0}; // *** or keep flag so no mult in exec??
862869

863-
// rescale the target s_k etc to s'_k etc...
864-
#pragma omp parallel for num_threads(opts.nthreads) schedule(static)
865-
for (BIGINT k = 0; k < nk; ++k) {
866-
for (int idim = 0; idim < dim; ++idim)
867-
STUp[idim][k] =
868-
t3P.h[idim] * t3P.gam[idim] * (STU_in[idim][k] - t3P.D[idim]); // so |s'_k| <
869-
// pi/R
870-
}
870+
KernelFseries<TF> fseries(spopts);
871871
// (old STEP 3a) Compute deconvolution post-factors array (per targ pt)...
872872
// (exploits that FT separates because kernel is prod of 1D funcs)
873873
deconv.resize(nk);
874-
std::array<std::vector<TF>, 3> phiHatk;
875-
for (int idim = 0; idim < dim; ++idim) {
876-
phiHatk[idim].resize(nk);
877-
onedim_nuft_kernel(nk, STUp[idim], phiHatk[idim], spopts); // fill phiHat1
878-
}
879874
// C can be nan or inf if M=0, no input NU pts
880-
int Cfinite =
875+
bool Cfinite =
881876
std::isfinite(t3P.C[0]) && std::isfinite(t3P.C[1]) && std::isfinite(t3P.C[2]);
882-
int Cnonzero = t3P.C[0] != 0.0 || t3P.C[1] != 0.0 || t3P.C[2] != 0.0; // cen
877+
bool Cnonzero = t3P.C[0] != 0.0 || t3P.C[1] != 0.0 || t3P.C[2] != 0.0; // cen
878+
bool do_phase = Cfinite && Cnonzero;
883879
#pragma omp parallel for num_threads(opts.nthreads) schedule(static)
884880
for (BIGINT k = 0; k < nk; ++k) { // .... loop over NU targ freqs
885881
TF phiHat = 1;
886-
for (int idim = 0; idim < dim; ++idim) phiHat *= phiHatk[idim][k];
887-
deconv[k] = (std::complex<TF>)(1.0 / phiHat);
888-
if (Cfinite && Cnonzero) {
889-
TF phase = 0;
890-
for (int idim = 0; idim < dim; ++idim)
891-
phase += (STU_in[idim][k] - t3P.D[idim]) * t3P.C[idim];
892-
deconv[k] *= std::polar(TF(1), isign * phase); // Euler e^{+-i.phase}
882+
TF phase = 0;
883+
for (int idim = 0; idim < dim; ++idim) {
884+
auto tSTUin = STU_in[idim][k];
885+
// rescale the target s_k etc to s'_k etc...
886+
auto tSTUp = t3P.h[idim] * t3P.gam[idim] * (tSTUin - t3P.D[idim]); // so |s'_k| <
887+
// pi/R
888+
phiHat *= fseries(tSTUp);
889+
if (do_phase) phase += (tSTUin - t3P.D[idim]) * t3P.C[idim];
890+
STUp[idim][k] = tSTUp;
893891
}
892+
deconv[k] = do_phase ? std::polar(TF(1) / phiHat, isign * phase) : TF(1) / phiHat;
894893
}
895894
if (opts.debug)
896895
printf("[%s t3] phase & deconv factors:\t%.3g s\n", __func__, timer.elapsedsec());

0 commit comments

Comments
 (0)