@@ -426,10 +426,10 @@ void deconvolveshuffle3d(int dir, FLT prefac, FLT *ker1, FLT *ker2, FLT *ker3, B
426426
427427// --------- batch helper functions for t1,2 exec: ---------------------------
428428
429- int spreadinterpSortedBatch (int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX * cBatch)
429+ int spreadinterpSortedBatch (int batchSize, FINUFFT_PLAN p, CPX *cBatch)
430430/*
431431 Spreads (or interpolates) a batch of batchSize strength vectors in cBatch
432- to (or from) the batch of fine working grids fwBatch, using the same set of
432+ to (or from) the batch of fine working grids p-> fwBatch, using the same set of
433433 (index-sorted) NU points p->X,Y,Z for each vector in the batch.
434434 The direction (spread vs interpolate) is set by p->spopts.spread_direction.
435435 Returns 0 (no error reporting for now).
@@ -448,20 +448,20 @@ int spreadinterpSortedBatch(int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX *cB
448448#endif
449449#pragma omp parallel for num_threads(nthr_outer)
450450 for (int i = 0 ; i < batchSize; i++) {
451- CPX *fwi = fwBatch + i * p->nf ; // start of i'th fw array in wkspace
452- CPX *ci = cBatch + i * p->nj ; // start of i'th c array in cBatch
451+ CPX *fwi = p-> fwBatch + i * p->nf ; // start of i'th fw array in wkspace
452+ CPX *ci = cBatch + i * p->nj ; // start of i'th c array in cBatch
453453 spreadinterpSorted (p->sortIndices , p->nf1 , p->nf2 , p->nf3 , (FLT *)fwi, p->nj , p->X ,
454454 p->Y , p->Z , (FLT *)ci, p->spopts , p->didSort );
455455 }
456456 return 0 ;
457457}
458458
459- int deconvolveBatch (int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX * fkBatch)
459+ int deconvolveBatch (int batchSize, FINUFFT_PLAN p, CPX *fkBatch)
460460/*
461- Type 1: deconvolves (amplifies) from each interior fw array in fwBatch
461+ Type 1: deconvolves (amplifies) from each interior fw array in p-> fwBatch
462462 into each output array fk in fkBatch.
463463 Type 2: deconvolves from user-supplied input fk to 0-padded interior fw,
464- again looping over fk in fkBatch and fw in fwBatch.
464+ again looping over fk in fkBatch and fw in p-> fwBatch.
465465 The direction (spread vs interpolate) is set by p->spopts.spread_direction.
466466 This is mostly a loop calling deconvolveshuffle?d for the needed dim batchSize
467467 times.
@@ -471,8 +471,8 @@ int deconvolveBatch(int batchSize, FINUFFT_PLAN p, CPX *fwBatch, CPX *fkBatch)
471471 // since deconvolveshuffle?d are single-thread, omp par seems to help here...
472472#pragma omp parallel for num_threads(batchSize)
473473 for (int i = 0 ; i < batchSize; i++) {
474- CPX *fwi = fwBatch + i * p->nf ; // start of i'th fw array in wkspace
475- CPX *fki = fkBatch + i * p->N ; // start of i'th fk array in fkBatch
474+ CPX *fwi = p-> fwBatch + i * p->nf ; // start of i'th fw array in wkspace
475+ CPX *fki = fkBatch + i * p->N ; // start of i'th fk array in fkBatch
476476
477477 // Call routine from common.cpp for the dim; prefactors hardcoded to 1.0...
478478 if (p->dim == 1 )
@@ -729,9 +729,12 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans,
729729 return FINUFFT_ERR_MAXNALLOC;
730730 }
731731
732- #ifndef FINUFFT_USE_DUCC0
733732 timer.restart ();
733+ #ifdef FINUFFT_USE_DUCC0
734+ p->fwBatch = (CPX *)malloc (p->nf * p->batchSize * sizeof (CPX)); // the big workspace
735+ #else
734736 p->fwBatch = (CPX *)FFTW_ALLOC_CPX (p->nf * p->batchSize ); // the big workspace
737+ #endif
735738 if (p->opts .debug )
736739 printf (" [%s] fwBatch %.2fGB alloc: \t %.3g s\n " , __func__,
737740 (double )1E-09 * sizeof (CPX) * p->nf * p->batchSize , timer.elapsedsec ());
@@ -744,6 +747,7 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans,
744747 return FINUFFT_ERR_ALLOC;
745748 }
746749
750+ #ifndef FINUFFT_USE_DUCC0
747751 timer.restart (); // plan the FFTW
748752 int *ns = gridsize_for_fft (p);
749753 // fftw_plan_many_dft args: rank, gridsize/dim, howmany, in, inembed, istride,
@@ -770,11 +774,9 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT *n_modes, int iflag, int ntrans,
770774 } else { // -------------------------- type 3 (no planning) ------------
771775
772776 if (p->opts .debug ) printf (" [%s] %dd%d: ntrans=%d\n " , __func__, dim, type, ntrans);
773- // in case destroy occurs before setpts, need safe dummy ptrs/plans...
774- #ifndef FINUFFT_USE_DUCC0
775- p->CpBatch = NULL ;
776- p->fwBatch = NULL ;
777- #endif
777+ // in case destroy occurs before setpts, need safe dummy ptrs/plans...
778+ p->CpBatch = NULL ;
779+ p->fwBatch = NULL ;
778780 p->Sp = NULL ;
779781 p->Tp = NULL ;
780782 p->Up = NULL ;
@@ -891,9 +893,13 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT
891893 __func__);
892894 return FINUFFT_ERR_MAXNALLOC;
893895 }
894- #ifndef FINUFFT_USE_DUCC0
896+ #ifdef FINUFFT_USE_DUCC0
897+ free (p->fwBatch );
898+ p->fwBatch = (CPX *)malloc (p->nf * p->batchSize * sizeof (CPX)); // maybe big workspace
899+ #else
895900 if (p->fwBatch ) FFTW_FR (p->fwBatch );
896901 p->fwBatch = (CPX *)FFTW_ALLOC_CPX (p->nf * p->batchSize ); // maybe big workspace
902+ #endif
897903
898904 // (note FFTW_ALLOC is not needed over malloc, but matches its type)
899905 if (p->CpBatch ) free (p->CpBatch );
@@ -908,7 +914,6 @@ int FINUFFT_SETPTS(FINUFFT_PLAN p, BIGINT nj, FLT *xj, FLT *yj, FLT *zj, BIGINT
908914 return FINUFFT_ERR_ALLOC;
909915 }
910916 // printf("fwbatch, cpbatch ptrs: %llx %llx\n",p->fwBatch,p->CpBatch);
911- #endif
912917
913918 // alloc rescaled NU src pts x'_j (in X etc), rescaled NU targ pts s'_k ...
914919 if (p->X ) free (p->X );
@@ -1073,13 +1078,6 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
10731078 CNTime timer;
10741079 timer.start ();
10751080
1076- #ifdef FINUFFT_USE_DUCC0
1077- std::vector<CPX> fwBatch_ (p->nf * p->batchSize ); // the big workspace
1078- CPX *fwBatch = fwBatch_.data ();
1079- #else
1080- CPX *fwBatch = p->fwBatch ;
1081- #endif
1082-
10831081 if (p->type != 3 ) { // --------------------- TYPE 1,2 EXEC ------------------
10841082
10851083 double t_sprint = 0.0 , t_fft = 0.0 , t_deconv = 0.0 ; // accumulated timing
@@ -1100,26 +1098,26 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11001098 // STEP 1: (varies by type)
11011099 timer.restart ();
11021100 if (p->type == 1 ) { // type 1: spread NU pts p->X, weights cj, to fw grid
1103- spreadinterpSortedBatch (thisBatchSize, p, fwBatch, cjb);
1101+ spreadinterpSortedBatch (thisBatchSize, p, cjb);
11041102 t_sprint += timer.elapsedsec ();
11051103 } else { // type 2: amplify Fourier coeffs fk into 0-padded fw
1106- deconvolveBatch (thisBatchSize, p, fwBatch, fkb);
1104+ deconvolveBatch (thisBatchSize, p, fkb);
11071105 t_deconv += timer.elapsedsec ();
11081106 }
11091107
11101108 // STEP 2: call the FFT on this batch
11111109 timer.restart ();
1112- do_fft (p, fwBatch );
1110+ do_fft (p);
11131111 t_fft += timer.elapsedsec ();
11141112 if (p->opts .debug > 1 ) printf (" \t FFT exec:\t\t %.3g s\n " , timer.elapsedsec ());
11151113
11161114 // STEP 3: (varies by type)
11171115 timer.restart ();
11181116 if (p->type == 1 ) { // type 1: deconvolve (amplify) fw and shuffle to fk
1119- deconvolveBatch (thisBatchSize, p, fwBatch, fkb);
1117+ deconvolveBatch (thisBatchSize, p, fkb);
11201118 t_deconv += timer.elapsedsec ();
11211119 } else { // type 2: interpolate unif fw grid to NU target pts
1122- spreadinterpSortedBatch (thisBatchSize, p, fwBatch, cjb);
1120+ spreadinterpSortedBatch (thisBatchSize, p, cjb);
11231121 t_sprint += timer.elapsedsec ();
11241122 }
11251123 } // ........end b loop
@@ -1148,13 +1146,6 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11481146 printf (" [%s t3] start ntrans=%d (%d batches, bsize=%d)...\n " , __func__, p->ntrans ,
11491147 p->nbatch , p->batchSize );
11501148
1151- #ifdef FINUFFT_USE_DUCC0
1152- std::vector<CPX> CpBatch_ (p->nj * p->batchSize ); // batch c' work
1153- CPX *CpBatch = CpBatch_.data ();
1154- #else
1155- CPX *CpBatch = p->CpBatch ;
1156- #endif
1157-
11581149 for (int b = 0 ; b * p->batchSize < p->ntrans ; b++) { // .....loop b over batches
11591150
11601151 // batching and pointers to this batch, identical to t1,2 above...
@@ -1171,14 +1162,14 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11711162 for (int i = 0 ; i < thisBatchSize; i++) {
11721163 BIGINT ioff = i * p->nj ;
11731164 for (BIGINT j = 0 ; j < p->nj ; ++j)
1174- CpBatch[ioff + j] = p->prephase [j] * cjb[ioff + j];
1165+ p-> CpBatch [ioff + j] = p->prephase [j] * cjb[ioff + j];
11751166 }
11761167 t_pre += timer.elapsedsec ();
11771168
11781169 // STEP 1: spread c'_j batch (x'_j NU pts) into fw batch grid...
11791170 timer.restart ();
1180- p->spopts .spread_direction = 1 ; // spread
1181- spreadinterpSortedBatch (thisBatchSize, p, fwBatch, CpBatch); // p->X are primed
1171+ p->spopts .spread_direction = 1 ; // spread
1172+ spreadinterpSortedBatch (thisBatchSize, p, p-> CpBatch ); // p->X are primed
11821173 t_spr += timer.elapsedsec ();
11831174
11841175 // for (int j=0;j<p->nf1;++j)
@@ -1191,7 +1182,7 @@ int FINUFFT_EXECUTE(FINUFFT_PLAN p, CPX *cj, CPX *fk) {
11911182 p->innerT2plan ->ntrans = thisBatchSize; // do not try this at home!
11921183 /* (alarming that FFT not shrunk, but safe, because t2's fwBatch array
11931184 still the same size, as Andrea explained; just wastes a few flops) */
1194- FINUFFT_EXECUTE (p->innerT2plan , fkb, (CPX *)( fwBatch) );
1185+ FINUFFT_EXECUTE (p->innerT2plan , fkb, p-> fwBatch );
11951186 t_t2 += timer.elapsedsec ();
11961187
11971188 // STEP 3: apply deconvolve (precomputed 1/phiHat(targ_k), phasing too)...
@@ -1227,7 +1218,9 @@ int FINUFFT_DESTROY(FINUFFT_PLAN p)
12271218 if (!p) // NULL ptr, so not a ptr to a plan, report error
12281219 return 1 ;
12291220
1230- #ifndef FINUFFT_USE_DUCC0
1221+ #ifdef FINUFFT_USE_DUCC0
1222+ free (p->fwBatch ); // free the big FFTW (or t3 spread) working array
1223+ #else
12311224 FFTW_FR (p->fwBatch ); // free the big FFTW (or t3 spread) working array
12321225#endif
12331226 free (p->sortIndices );
@@ -1243,9 +1236,7 @@ int FINUFFT_DESTROY(FINUFFT_PLAN p)
12431236 free (p->phiHat3 );
12441237 } else { // free the stuff alloc for type 3 only
12451238 FINUFFT_DESTROY (p->innerT2plan ); // if NULL, ignore its error code
1246- #ifndef FINUFFT_USE_DUCC0
12471239 free (p->CpBatch );
1248- #endif
12491240 free (p->Sp );
12501241 free (p->Tp );
12511242 free (p->Up );
0 commit comments