Skip to content

Commit 17e8e04

Browse files
committed
allocate big arrays during plan generation unconditionally
1 parent 8b4fadf commit 17e8e04

File tree

4 files changed

+57
-70
lines changed

4 files changed

+57
-70
lines changed

include/finufft/defs.h

Lines changed: 20 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -200,33 +200,31 @@ typedef struct FINUFFT_PLAN_S { // the main plan object, fully C++
200200

201201
int type; // transform type (Rokhlin naming): 1,2 or 3
202202
int dim; // overall dimension: 1,2 or 3
203-
int ntrans; // how many transforms to do at once (vector or "many" mode)
204-
BIGINT nj; // num of NU pts in type 1,2 (for type 3, num input x pts)
205-
BIGINT nk; // number of NU freq pts (type 3 only)
206-
FLT tol; // relative user tolerance
207-
int batchSize; // # strength vectors to group together for FFTW, etc
208-
int nbatch; // how many batches done to cover all ntrans vectors
203+
int ntrans; // how many transforms to do at once (vector or "many" mode)
204+
BIGINT nj; // num of NU pts in type 1,2 (for type 3, num input x pts)
205+
BIGINT nk; // number of NU freq pts (type 3 only)
206+
FLT tol; // relative user tolerance
207+
int batchSize; // # strength vectors to group together for FFTW, etc
208+
int nbatch; // how many batches done to cover all ntrans vectors
209209

210-
BIGINT ms; // number of modes in x (1) dir (historical CMCL name) = N1
211-
BIGINT mt; // number of modes in y (2) direction = N2
212-
BIGINT mu; // number of modes in z (3) direction = N3
213-
BIGINT N; // total # modes (prod of above three)
210+
BIGINT ms; // number of modes in x (1) dir (historical CMCL name) = N1
211+
BIGINT mt; // number of modes in y (2) direction = N2
212+
BIGINT mu; // number of modes in z (3) direction = N3
213+
BIGINT N; // total # modes (prod of above three)
214214

215-
BIGINT nf1; // size of internal fine grid in x (1) direction
216-
BIGINT nf2; // " y (2)
217-
BIGINT nf3; // " z (3)
218-
BIGINT nf; // total # fine grid points (product of the above three)
215+
BIGINT nf1; // size of internal fine grid in x (1) direction
216+
BIGINT nf2; // " y (2)
217+
BIGINT nf3; // " z (3)
218+
BIGINT nf; // total # fine grid points (product of the above three)
219219

220-
int fftSign; // sign in exponential for NUFFT defn, guaranteed to be +-1
220+
int fftSign; // sign in exponential for NUFFT defn, guaranteed to be +-1
221221

222-
FLT *phiHat1; // FT of kernel in t1,2, on x-axis mode grid
223-
FLT *phiHat2; // " y-axis.
224-
FLT *phiHat3; // " z-axis.
222+
FLT *phiHat1; // FT of kernel in t1,2, on x-axis mode grid
223+
FLT *phiHat2; // " y-axis.
224+
FLT *phiHat3; // " z-axis.
225225

226-
#ifndef FINUFFT_USE_DUCC0
227-
CPX *fwBatch; // (batches of) fine grid(s) for FFTW to plan
228-
// & act on. Usually the largest working array
229-
#endif
226+
CPX *fwBatch; // (batches of) fine grid(s) for FFTW to plan
227+
// & act on. Usually the largest working array
230228

231229
BIGINT *sortIndices; // precomputed NU pt permutation, speeds spread/interp
232230
bool didSort; // whether binsorting used (false: identity perm used)
@@ -238,9 +236,7 @@ typedef struct FINUFFT_PLAN_S { // the main plan object, fully C++
238236
FLT *S, *T, *U; // pointers to user's target NU pts arrays (no new allocs)
239237
CPX *prephase; // pre-phase, for all input NU pts
240238
CPX *deconv; // reciprocal of kernel FT, phase, all output NU pts
241-
#ifndef FINUFFT_USE_DUCC0
242239
CPX *CpBatch; // working array of prephased strengths
243-
#endif
244240
FLT *Sp, *Tp, *Up; // internal primed targs (s'_k, etc), allocated
245241
TYPE3PARAMS t3P; // groups together type 3 shift, scale, phase, parameters
246242
FINUFFT_PLAN innerT2plan; // ptr used for type 2 in step 2 of type 3

include/finufft/fft.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,6 @@
1313
#include <finufft/defs.h>
1414

1515
int *gridsize_for_fft(FINUFFT_PLAN p);
16-
void do_fft(FINUFFT_PLAN p, CPX *fwBatch);
16+
void do_fft(FINUFFT_PLAN p);
1717

1818
#endif // FINUFFT_INCLUDE_FINUFFT_FFT_H

src/fft.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -24,7 +24,7 @@ int *gridsize_for_fft(FINUFFT_PLAN p) {
2424
return nf;
2525
}
2626

27-
void do_fft(FINUFFT_PLAN p, CPX *fwBatch) {
27+
void do_fft(FINUFFT_PLAN p) {
2828
#ifdef FINUFFT_USE_DUCC0
2929
size_t nthreads = min<size_t>(MY_OMP_GET_MAX_THREADS(), p->opts.nthreads);
3030
int *ns = gridsize_for_fft(p);
@@ -40,7 +40,7 @@ void do_fft(FINUFFT_PLAN p, CPX *fwBatch) {
4040
arrdims.push_back(size_t(ns[2]));
4141
axes.push_back(3);
4242
}
43-
ducc0::vfmav<CPX> data(fwBatch, arrdims);
43+
ducc0::vfmav<CPX> data(p->fwBatch, arrdims);
4444
#ifdef FINUFFT_NO_DUCC0_TWEAKS
4545
ducc0::c2c(data, data, axes, p->fftSign < 0, FLT(1), nthreads);
4646
#else

src/finufft.cpp

Lines changed: 34 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -426,10 +426,10 @@ void deconvolveshuffle3d(int dir, FLT prefac, FLT *ker1, FLT *ker2, FLT *ker3, B
426426

427427
// --------- batch helper functions for t1,2 exec: ---------------------------
428428

429-
int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX *cBatch)
429+
int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN p, CPX *cBatch)
430430
/*
431431
Spreads (or interpolates) a batch of batchSize strength vectors in cBatch
432-
to (or from) the batch of fine working grids fwBatch, using the same set of
432+
to (or from) the batch of fine working grids p->fwBatch, using the same set of
433433
(index-sorted) NU points p->X,Y,Z for each vector in the batch.
434434
The direction (spread vs interpolate) is set by p->spopts.spread_direction.
435435
Returns 0 (no error reporting for now).
@@ -448,20 +448,20 @@ int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX *cB
448448
#endif
449449
#pragma omp parallel for num_threads(nthr_outer)
450450
for (int i = 0; i < batchSize; i++) {
451-
CPX *fwi = fwBatch + i * p->nf; // start of i'th fw array in wkspace
452-
CPX *ci = cBatch + i * p->nj; // start of i'th c array in cBatch
451+
CPX *fwi = p->fwBatch + i * p->nf; // start of i'th fw array in wkspace
452+
CPX *ci = cBatch + i * p->nj; // start of i'th c array in cBatch
453453
spreadinterpSorted(p->sortIndices, p->nf1, p->nf2, p->nf3, (FLT *)fwi, p->nj, p->X,
454454
p->Y, p->Z, (FLT *)ci, p->spopts, p->didSort);
455455
}
456456
return 0;
457457
}
458458

459-
int deconvolveBatch(int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX *fkBatch)
459+
int deconvolveBatch(int batchSize, FINUFFT_PLAN p, CPX *fkBatch)
460460
/*
461-
Type 1: deconvolves (amplifies) from each interior fw array in fwBatch
461+
Type 1: deconvolves (amplifies) from each interior fw array in p->fwBatch
462462
into each output array fk in fkBatch.
463463
Type 2: deconvolves from user-supplied input fk to 0-padded interior fw,
464-
again looping over fk in fkBatch and fw in fwBatch.
464+
again looping over fk in fkBatch and fw in p->fwBatch.
465465
The direction (spread vs interpolate) is set by p->spopts.spread_direction.
466466
This is mostly a loop calling deconvolveshuffle?d for the needed dim batchSize
467467
times.
@@ -471,8 +471,8 @@ int deconvolveBatch(int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX *fkBatch)
471471
// since deconvolveshuffle?d are single-thread, omp par seems to help here...
472472
#pragma omp parallel for num_threads(batchSize)
473473
for (int i = 0; i < batchSize; i++) {
474-
CPX *fwi = fwBatch + i * p->nf; // start of i'th fw array in wkspace
475-
CPX *fki = fkBatch + i * p->N; // start of i'th fk array in fkBatch
474+
CPX *fwi = p->fwBatch + i * p->nf; // start of i'th fw array in wkspace
475+
CPX *fki = fkBatch + i * p->N; // start of i'th fk array in fkBatch
476476

477477
// Call routine from common.cpp for the dim; prefactors hardcoded to 1.0...
478478
if (p->dim == 1)
@@ -729,9 +729,12 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans,
729729
return FINUFFT_ERR_MAXNALLOC;
730730
}
731731

732-
#ifndef FINUFFT_USE_DUCC0
733732
timer.restart();
733+
#ifdef FINUFFT_USE_DUCC0
734+
p->fwBatch = (CPX *)malloc(p->nf * p->batchSize * sizeof(CPX)); // the big workspace
735+
#else
734736
p->fwBatch = (CPX *)FFTW_ALLOC_CPX(p->nf * p->batchSize); // the big workspace
737+
#endif
735738
if (p->opts.debug)
736739
printf("[%s] fwBatch %.2fGB alloc: \t%.3g s\n", __func__,
737740
(double)1E-09 * sizeof(CPX) * p->nf * p->batchSize, timer.elapsedsec());
@@ -744,6 +747,7 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans,
744747
return FINUFFT_ERR_ALLOC;
745748
}
746749

750+
#ifndef FINUFFT_USE_DUCC0
747751
timer.restart(); // plan the FFTW
748752
int *ns = gridsize_for_fft(p);
749753
// fftw_plan_many_dft args: rank, gridsize/dim, howmany, in, inembed, istride,
@@ -770,11 +774,9 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans,
770774
} else { // -------------------------- type 3 (no planning) ------------
771775

772776
if (p->opts.debug) printf("[%s] %dd%d: ntrans=%d\n", __func__, dim, type, ntrans);
773-
// in case destroy occurs before setpts, need safe dummy ptrs/plans...
774-
#ifndef FINUFFT_USE_DUCC0
775-
p->CpBatch = NULL;
776-
p->fwBatch = NULL;
777-
#endif
777+
// in case destroy occurs before setpts, need safe dummy ptrs/plans...
778+
p->CpBatch = NULL;
779+
p->fwBatch = NULL;
778780
p->Sp = NULL;
779781
p->Tp = NULL;
780782
p->Up = NULL;
@@ -891,9 +893,13 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT
891893
__func__);
892894
return FINUFFT_ERR_MAXNALLOC;
893895
}
894-
#ifndef FINUFFT_USE_DUCC0
896+
#ifdef FINUFFT_USE_DUCC0
897+
free(p->fwBatch);
898+
p->fwBatch = (CPX *)malloc(p->nf * p->batchSize * sizeof(CPX)); // maybe big workspace
899+
#else
895900
if (p->fwBatch) FFTW_FR(p->fwBatch);
896901
p->fwBatch = (CPX *)FFTW_ALLOC_CPX(p->nf * p->batchSize); // maybe big workspace
902+
#endif
897903

898904
// (note FFTW_ALLOC is not needed over malloc, but matches its type)
899905
if (p->CpBatch) free(p->CpBatch);
@@ -908,7 +914,6 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT
908914
return FINUFFT_ERR_ALLOC;
909915
}
910916
// printf("fwbatch, cpbatch ptrs: %llx %llx\n",p->fwBatch,p->CpBatch);
911-
#endif
912917

913918
// alloc rescaled NU src pts x'_j (in X etc), rescaled NU targ pts s'_k ...
914919
if (p->X) free(p->X);
@@ -1073,13 +1078,6 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
10731078
CNTime timer;
10741079
timer.start();
10751080

1076-
#ifdef FINUFFT_USE_DUCC0
1077-
std::vector<CPX> fwBatch_(p->nf * p->batchSize); // the big workspace
1078-
CPX *fwBatch = fwBatch_.data();
1079-
#else
1080-
CPX *fwBatch = p->fwBatch;
1081-
#endif
1082-
10831081
if (p->type != 3) { // --------------------- TYPE 1,2 EXEC ------------------
10841082

10851083
double t_sprint = 0.0, t_fft = 0.0, t_deconv = 0.0; // accumulated timing
@@ -1100,26 +1098,26 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11001098
// STEP 1: (varies by type)
11011099
timer.restart();
11021100
if (p->type == 1) { // type 1: spread NU pts p->X, weights cj, to fw grid
1103-
spreadinterpSortedBatch(thisBatchSize, p, fwBatch, cjb);
1101+
spreadinterpSortedBatch(thisBatchSize, p, cjb);
11041102
t_sprint += timer.elapsedsec();
11051103
} else { // type 2: amplify Fourier coeffs fk into 0-padded fw
1106-
deconvolveBatch(thisBatchSize, p, fwBatch, fkb);
1104+
deconvolveBatch(thisBatchSize, p, fkb);
11071105
t_deconv += timer.elapsedsec();
11081106
}
11091107

11101108
// STEP 2: call the FFT on this batch
11111109
timer.restart();
1112-
do_fft(p, fwBatch);
1110+
do_fft(p);
11131111
t_fft += timer.elapsedsec();
11141112
if (p->opts.debug > 1) printf("\tFFT exec:\t\t%.3g s\n", timer.elapsedsec());
11151113

11161114
// STEP 3: (varies by type)
11171115
timer.restart();
11181116
if (p->type == 1) { // type 1: deconvolve (amplify) fw and shuffle to fk
1119-
deconvolveBatch(thisBatchSize, p, fwBatch, fkb);
1117+
deconvolveBatch(thisBatchSize, p, fkb);
11201118
t_deconv += timer.elapsedsec();
11211119
} else { // type 2: interpolate unif fw grid to NU target pts
1122-
spreadinterpSortedBatch(thisBatchSize, p, fwBatch, cjb);
1120+
spreadinterpSortedBatch(thisBatchSize, p, cjb);
11231121
t_sprint += timer.elapsedsec();
11241122
}
11251123
} // ........end b loop
@@ -1148,13 +1146,6 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11481146
printf("[%s t3] start ntrans=%d (%d batches, bsize=%d)...\n", __func__, p->ntrans,
11491147
p->nbatch, p->batchSize);
11501148

1151-
#ifdef FINUFFT_USE_DUCC0
1152-
std::vector<CPX> CpBatch_(p->nj * p->batchSize); // batch c' work
1153-
CPX *CpBatch = CpBatch_.data();
1154-
#else
1155-
CPX *CpBatch = p->CpBatch;
1156-
#endif
1157-
11581149
for (int b = 0; b * p->batchSize < p->ntrans; b++) { // .....loop b over batches
11591150

11601151
// batching and pointers to this batch, identical to t1,2 above...
@@ -1171,14 +1162,14 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11711162
for (int i = 0; i < thisBatchSize; i++) {
11721163
BIGINT ioff = i * p->nj;
11731164
for (BIGINT j = 0; j < p->nj; ++j)
1174-
CpBatch[ioff + j] = p->prephase[j] * cjb[ioff + j];
1165+
p->CpBatch[ioff + j] = p->prephase[j] * cjb[ioff + j];
11751166
}
11761167
t_pre += timer.elapsedsec();
11771168

11781169
// STEP 1: spread c'_j batch (x'_j NU pts) into fw batch grid...
11791170
timer.restart();
1180-
p->spopts.spread_direction = 1; // spread
1181-
spreadinterpSortedBatch(thisBatchSize, p, fwBatch, CpBatch); // p->X are primed
1171+
p->spopts.spread_direction = 1; // spread
1172+
spreadinterpSortedBatch(thisBatchSize, p, p->CpBatch); // p->X are primed
11821173
t_spr += timer.elapsedsec();
11831174

11841175
// for (int j=0;j<p->nf1;++j)
@@ -1191,7 +1182,7 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11911182
p->innerT2plan->ntrans = thisBatchSize; // do not try this at home!
11921183
/* (alarming that FFT not shrunk, but safe, because t2's fwBatch array
11931184
still the same size, as Andrea explained; just wastes a few flops) */
1194-
FINUFFT_EXECUTE(p->innerT2plan, fkb, (CPX *)(fwBatch));
1185+
FINUFFT_EXECUTE(p->innerT2plan, fkb, p->fwBatch);
11951186
t_t2 += timer.elapsedsec();
11961187

11971188
// STEP 3: apply deconvolve (precomputed 1/phiHat(targ_k), phasing too)...
@@ -1227,7 +1218,9 @@ int FINUFFT_DESTROY(FINUFFT_PLAN p)
12271218
if (!p) // NULL ptr, so not a ptr to a plan, report error
12281219
return 1;
12291220

1230-
#ifndef FINUFFT_USE_DUCC0
1221+
#ifdef FINUFFT_USE_DUCC0
1222+
free(p->fwBatch); // free the big FFTW (or t3 spread) working array
1223+
#else
12311224
FFTW_FR(p->fwBatch); // free the big FFTW (or t3 spread) working array
12321225
#endif
12331226
free(p->sortIndices);
@@ -1243,9 +1236,7 @@ int FINUFFT_DESTROY(FINUFFT_PLAN p)
12431236
free(p->phiHat3);
12441237
} else { // free the stuff alloc for type 3 only
12451238
FINUFFT_DESTROY(p->innerT2plan); // if NULL, ignore its error code
1246-
#ifndef FINUFFT_USE_DUCC0
12471239
free(p->CpBatch);
1248-
#endif
12491240
free(p->Sp);
12501241
free(p->Tp);
12511242
free(p->Up);

0 commit comments

Comments
 (0)