Skip to content

Commit c0845c9

Browse files
committed
rework nthr logic replacing #431: no cap on user o.nthreads override; ifndef _OPENMP hard nthr=1 override; warnings at the finufft level only; binsort controlled by spopts.sort_threads matching docs
1 parent cc8629f commit c0845c9

File tree

3 files changed

+36
-18
lines changed

3 files changed

+36
-18
lines changed

CHANGELOG

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
List of features / changes made / release notes, in reverse chronological order.
22
If not stated, FINUFFT is assumed (cuFINUFFT <=1.3 is listed separately).
33

4-
* CPU plan stage prevents now caps # threads at omp_get_max_threads (being 1
5-
for single-thread build); warns if this cap was activated (PR 431)
4+
* CPU plan stage allows any # threads, warns if > omp_get_max_threads(); or
5+
if the library is single-threaded, fixes nthr=1 and warns if opts.nthreads>1 was requested.
6+
Sort now respects spread_opts.sort_threads not nthreads. Supersedes PR 431.
67
* new docs troubleshooting accuracy limitations due to condition number of the
78
NUFFT problem.
89
* new sanity check on nj and nk (<0 or too big); new err code, tester, doc.

src/finufft.cpp

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -591,15 +591,22 @@ int FINUFFT_MAKEPLAN(int type, int dim, BIGINT* n_modes, int iflag,
591591
p->fftSign = (iflag>=0) ? 1 : -1; // clean up flag input
592592

593593
// choose overall # threads...
594-
int maxnthr = MY_OMP_GET_MAX_THREADS();
595-
int nthr = maxnthr; // use as many as OMP gives us
594+
#ifdef _OPENMP
595+
int ompmaxnthr = MY_OMP_GET_MAX_THREADS();
596+
int nthr = ompmaxnthr; // default: use as many as OMP gives us
597+
// (the above could be set, or suggested set, to 1 for small enough problems...)
596598
if (p->opts.nthreads>0) {
597-
nthr = min(maxnthr,p->opts.nthreads); // user override up to max avail
598-
if (p->opts.nthreads > maxnthr) // if no OMP, maxnthr=1
599-
fprintf(stderr,"%s warning: user requested %d threads, but only %d threads available; enforcing nthreads=%d.\n",__func__,p->opts.nthreads,maxnthr,nthr);
599+
nthr = p->opts.nthreads; // user override, now without limit
600+
if (p->opts.showwarn && (nthr > ompmaxnthr))
601+
fprintf(stderr,"%s warning: using opts.nthreads=%d, more than the %d OpenMP claims available; note large nthreads can be slower.\n",__func__,nthr,ompmaxnthr);
600602
}
603+
#else
604+
int nthr = 1; // always 1 thread (avoid segfault)
605+
if (p->opts.nthreads>1)
606+
fprintf(stderr,"%s warning: opts.nthreads=%d but library is single-threaded; ignoring!\n",__func__,p->opts.nthreads);
607+
#endif
601608
p->opts.nthreads = nthr; // store actual # thr planned for
602-
// (this sets all downstream spread/interp, 1dkernel, and FFT thread counts)
609+
// (this sets/limits all downstream spread/interp, 1dkernel, and FFT thread counts...)
603610

604611
// choose batchSize for types 1,2 or 3... (uses int ceil(b/a)=1+(b-1)/a trick)
605612
if (p->opts.maxbatchsize==0) { // logic to auto-set best batchsize

src/spreadinterp.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -267,14 +267,18 @@ int indexSort(BIGINT* sort_indices, BIGINT N1, BIGINT N2, BIGINT N3, BIGINT M,
267267
timer.start(); // if needed, sort all the NU pts...
268268
int did_sort=0;
269269
int maxnthr = MY_OMP_GET_MAX_THREADS();
270-
if (opts.nthreads>0) // user override up to max avail
271-
maxnthr = min(maxnthr,opts.nthreads);
272-
270+
if (opts.sort_threads>0) // user override, now without limit
271+
maxnthr = opts.sort_threads; // maxnthr = the max threads sorting could use
272+
// (we don't print warning here, since: no showwarn in spread_opts, and finufft
273+
// already warned about it. spreadinterp-only advanced users will miss a warning)
273274
if (opts.sort==1 || (opts.sort==2 && better_to_sort)) {
274275
// store a good permutation ordering of all NU pts (dim=1,2 or 3)
275276
int sort_debug = (opts.debug>=2); // show timing output?
276-
int sort_nthr = opts.sort_threads; // choose # threads for sorting
277-
if (sort_nthr==0) // use auto choice: when N>>M, one thread is better!
277+
int sort_nthr = opts.sort_threads; // 0, or proposed max # threads for sorting
278+
#ifndef _OPENMP
279+
sort_nthr = 1; // if single-threaded lib, override user
280+
#endif
281+
if (sort_nthr==0) // multithreaded auto choice: when N>>M, one thread is better!
278282
sort_nthr = (10*M>N) ? maxnthr : 1; // heuristic
279283
if (sort_nthr==1)
280284
bin_sort_singlethread(sort_indices,M,kx,ky,kz,N1,N2,N3,opts.pirange,bin_size_x,bin_size_y,bin_size_z,sort_debug);
@@ -323,9 +327,12 @@ int spreadSorted(BIGINT* sort_indices,BIGINT N1, BIGINT N2, BIGINT N3,
323327
int ndims = ndims_from_Ns(N1,N2,N3);
324328
BIGINT N=N1*N2*N3; // output array size
325329
int ns=opts.nspread; // abbrev. for w, kernel width
326-
int nthr = MY_OMP_GET_MAX_THREADS(); // # threads to use to spread
330+
int nthr = MY_OMP_GET_MAX_THREADS(); // guess # threads to use to spread
327331
if (opts.nthreads>0)
328-
nthr = min(nthr,opts.nthreads); // user override up to max avail
332+
nthr = opts.nthreads; // user override, now without limit
333+
#ifndef _OPENMP
334+
nthr = 1; // if single-threaded lib, override user
335+
#endif
329336
if (opts.debug)
330337
printf("\tspread %dD (M=%lld; N1=%lld,N2=%lld,N3=%lld; pir=%d), nthr=%d\n",ndims,(long long)M,(long long)N1,(long long)N2,(long long)N3,opts.pirange,nthr);
331338

@@ -445,9 +452,12 @@ int interpSorted(BIGINT* sort_indices,BIGINT N1, BIGINT N2, BIGINT N3,
445452
int ndims = ndims_from_Ns(N1,N2,N3);
446453
int ns=opts.nspread; // abbrev. for w, kernel width
447454
FLT ns2 = (FLT)ns/2; // half spread width, used as stencil shift
448-
int nthr = MY_OMP_GET_MAX_THREADS(); // # threads to use to interp
455+
int nthr = MY_OMP_GET_MAX_THREADS(); // guess # threads to use to interp
449456
if (opts.nthreads>0)
450-
nthr = min(nthr,opts.nthreads); // user override up to max avail
457+
nthr = opts.nthreads; // user override, now without limit
458+
#ifndef _OPENMP
459+
nthr = 1; // if single-threaded lib, override user
460+
#endif
451461
if (opts.debug)
452462
printf("\tinterp %dD (M=%lld; N1=%lld,N2=%lld,N3=%lld; pir=%d), nthr=%d\n",ndims,(long long)M,(long long)N1,(long long)N2,(long long)N3,opts.pirange,nthr);
453463

@@ -1292,7 +1302,7 @@ void bin_sort_multithread(BIGINT *ret, BIGINT M, FLT *kx, FLT *ky, FLT *kz,
12921302
nbins2 = isky ? N2/bin_size_y+1 : 1;
12931303
nbins3 = iskz ? N3/bin_size_z+1 : 1;
12941304
BIGINT nbins = nbins1*nbins2*nbins3;
1295-
if (nthr==0)
1305+
if (nthr==0) // should never happen in spreadinterp use
12961306
fprintf(stderr,"[%s] nthr (%d) must be positive!\n",__func__,nthr);
12971307
int nt = min(M,(BIGINT)nthr); // handle case of less points than threads
12981308
std::vector<BIGINT> brk(nt+1); // list of start NU pt indices per thread

0 commit comments

Comments
 (0)