@@ -12,8 +12,8 @@ namespace deconvolve {
1212/* Kernel for copying fw to fk with amplification by prefac/ker */
1313// Note: assume modeord=0: CMCL-compatible mode ordering in fk (from -N/2 up
1414// to N/2-1), modeord=1: FFT-compatible mode ordering in fk (from 0 to N/2-1, then -N/2 up to -1).
15- template <typename T>
16- __global__ void deconvolve_1d (int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord ) {
15+ template <typename T, int modeord >
16+ __global__ void deconvolve_1d (int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1) {
1717 int pivot1, w1, fwkerind1;
1818 T kervalue;
1919
@@ -34,9 +34,9 @@ __global__ void deconvolve_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex
3434 }
3535}
3636
37- template <typename T>
37+ template <typename T, int modeord >
3838__global__ void deconvolve_2d (int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1,
39- T *fwkerhalf2, int modeord ) {
39+ T *fwkerhalf2) {
4040 int pivot1, pivot2, w1, w2, fwkerind1, fwkerind2;
4141 int k1, k2, inidx, outidx;
4242 T kervalue;
@@ -69,9 +69,9 @@ __global__ void deconvolve_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T>
6969 }
7070}
7171
72- template <typename T>
72+ template <typename T, int modeord >
7373__global__ void deconvolve_3d (int ms, int mt, int mu, int nf1, int nf2, int nf3, cuda_complex<T> *fw,
74- cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord ) {
74+ cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3) {
7575 int pivot1, pivot2, pivot3, w1, w2, w3, fwkerind1, fwkerind2, fwkerind3;
7676 int k1, k2, k3, inidx, outidx;
7777 T kervalue;
@@ -112,8 +112,8 @@ __global__ void deconvolve_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3,
112112}
113113
114114/* Kernel for copying fk to fw with same amplification */
115- template <typename T>
116- __global__ void amplify_1d (int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord ) {
115+ template <typename T, int modeord >
116+ __global__ void amplify_1d (int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1) {
117117 int pivot1, w1, fwkerind1;
118118 T kervalue;
119119
@@ -134,9 +134,9 @@ __global__ void amplify_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T>
134134 }
135135}
136136
137- template <typename T>
137+ template <typename T, int modeord >
138138__global__ void amplify_2d (int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1,
139- T *fwkerhalf2, int modeord ) {
139+ T *fwkerhalf2) {
140140 int pivot1, pivot2, w1, w2, fwkerind1, fwkerind2;
141141 int k1, k2, inidx, outidx;
142142 T kervalue;
@@ -169,9 +169,9 @@ __global__ void amplify_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw
169169 }
170170}
171171
172- template <typename T>
172+ template <typename T, int modeord >
173173__global__ void amplify_3d (int ms, int mt, int mu, int nf1, int nf2, int nf3, cuda_complex<T> *fw, cuda_complex<T> *fk,
174- T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord ) {
174+ T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3) {
175175 int pivot1, pivot2, pivot3, w1, w2, w3, fwkerind1, fwkerind2, fwkerind3;
176176 int k1, k2, k3, inidx, outidx;
177177 T kervalue;
@@ -211,7 +211,7 @@ __global__ void amplify_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, cu
211211 }
212212}
213213
214- template <typename T>
214+ template <typename T, int modeord >
215215int cudeconvolve1d (cufinufft_plan_t <T> *d_plan, int blksize)
216216/*
217217 wrapper for deconvolution & amplification in 1D.
@@ -228,20 +228,20 @@ int cudeconvolve1d(cufinufft_plan_t<T> *d_plan, int blksize)
228228
229229 if (d_plan->spopts .spread_direction == 1 ) {
230230 for (int t = 0 ; t < blksize; t++) {
231- deconvolve_1d<<<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, nf1, d_plan->fw + t * nf1,
232- d_plan->fk + t * nmodes, d_plan->fwkerhalf1 , d_plan-> opts . modeord );
231+ deconvolve_1d<T, modeord> < <<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, nf1, d_plan->fw + t * nf1,
232+ d_plan->fk + t * nmodes, d_plan->fwkerhalf1 );
233233 }
234234 } else {
235235 checkCudaErrors (cudaMemsetAsync (d_plan->fw , 0 , maxbatchsize * nf1 * sizeof (cuda_complex<T>), stream));
236236 for (int t = 0 ; t < blksize; t++) {
237- amplify_1d<<<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, nf1, d_plan->fw + t * nf1,
238- d_plan->fk + t * nmodes, d_plan->fwkerhalf1 , d_plan-> opts . modeord );
237+ amplify_1d<T, modeord> < <<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, nf1, d_plan->fw + t * nf1,
238+ d_plan->fk + t * nmodes, d_plan->fwkerhalf1 );
239239 }
240240 }
241241 return 0 ;
242242}
243243
244- template <typename T>
244+ template <typename T, int modeord >
245245int cudeconvolve2d (cufinufft_plan_t <T> *d_plan, int blksize)
246246/*
247247 wrapper for deconvolution & amplification in 2D.
@@ -260,22 +260,22 @@ int cudeconvolve2d(cufinufft_plan_t<T> *d_plan, int blksize)
260260
261261 if (d_plan->spopts .spread_direction == 1 ) {
262262 for (int t = 0 ; t < blksize; t++) {
263- deconvolve_2d<<<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
263+ deconvolve_2d<T, modeord> < <<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
264264 d_plan->fk + t * nmodes, d_plan->fwkerhalf1 ,
265- d_plan->fwkerhalf2 , d_plan-> opts . modeord );
265+ d_plan->fwkerhalf2 );
266266 }
267267 } else {
268268 checkCudaErrors (cudaMemsetAsync (d_plan->fw , 0 , maxbatchsize * nf1 * nf2 * sizeof (cuda_complex<T>), stream));
269269 for (int t = 0 ; t < blksize; t++) {
270- amplify_2d<<<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
270+ amplify_2d<T, modeord> < <<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
271271 d_plan->fk + t * nmodes, d_plan->fwkerhalf1 ,
272- d_plan->fwkerhalf2 , d_plan-> opts . modeord );
272+ d_plan->fwkerhalf2 );
273273 }
274274 }
275275 return 0 ;
276276}
277277
278- template <typename T>
278+ template <typename T, int modeord >
279279int cudeconvolve3d (cufinufft_plan_t <T> *d_plan, int blksize)
280280/*
281281 wrapper for deconvolution & amplification in 3D.
@@ -295,28 +295,34 @@ int cudeconvolve3d(cufinufft_plan_t<T> *d_plan, int blksize)
295295 int maxbatchsize = d_plan->maxbatchsize ;
296296 if (d_plan->spopts .spread_direction == 1 ) {
297297 for (int t = 0 ; t < blksize; t++) {
298- deconvolve_3d<<<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (
298+ deconvolve_3d<T, modeord> < <<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (
299299 ms, mt, mu, nf1, nf2, nf3, d_plan->fw + t * nf1 * nf2 * nf3, d_plan->fk + t * nmodes,
300- d_plan->fwkerhalf1 , d_plan->fwkerhalf2 , d_plan->fwkerhalf3 , d_plan-> opts . modeord );
300+ d_plan->fwkerhalf1 , d_plan->fwkerhalf2 , d_plan->fwkerhalf3 );
301301 }
302302 } else {
303303 checkCudaErrors (
304304 cudaMemsetAsync (d_plan->fw , 0 , maxbatchsize * nf1 * nf2 * nf3 * sizeof (cuda_complex<T>), stream));
305305 for (int t = 0 ; t < blksize; t++) {
306- amplify_3d<<<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (
306+ amplify_3d<T, modeord> < <<(nmodes + 256 - 1 ) / 256 , 256 , 0 , stream>>> (
307307 ms, mt, mu, nf1, nf2, nf3, d_plan->fw + t * nf1 * nf2 * nf3, d_plan->fk + t * nmodes,
308- d_plan->fwkerhalf1 , d_plan->fwkerhalf2 , d_plan->fwkerhalf3 , d_plan-> opts . modeord );
308+ d_plan->fwkerhalf1 , d_plan->fwkerhalf2 , d_plan->fwkerhalf3 );
309309 }
310310 }
311311 return 0 ;
312312}
313313
314- template int cudeconvolve1d<float >(cufinufft_plan_t <float > *d_plan, int blksize);
315- template int cudeconvolve1d<double >(cufinufft_plan_t <double > *d_plan, int blksize);
316- template int cudeconvolve2d<float >(cufinufft_plan_t <float > *d_plan, int blksize);
317- template int cudeconvolve2d<double >(cufinufft_plan_t <double > *d_plan, int blksize);
318- template int cudeconvolve3d<float >(cufinufft_plan_t <float > *d_plan, int blksize);
319- template int cudeconvolve3d<double >(cufinufft_plan_t <double > *d_plan, int blksize);
314+ template int cudeconvolve1d<float , 0 >(cufinufft_plan_t <float > *d_plan, int blksize);
315+ template int cudeconvolve1d<float , 1 >(cufinufft_plan_t <float > *d_plan, int blksize);
316+ template int cudeconvolve1d<double , 0 >(cufinufft_plan_t <double > *d_plan, int blksize);
317+ template int cudeconvolve1d<double , 1 >(cufinufft_plan_t <double > *d_plan, int blksize);
318+ template int cudeconvolve2d<float , 0 >(cufinufft_plan_t <float > *d_plan, int blksize);
319+ template int cudeconvolve2d<float , 1 >(cufinufft_plan_t <float > *d_plan, int blksize);
320+ template int cudeconvolve2d<double , 0 >(cufinufft_plan_t <double > *d_plan, int blksize);
321+ template int cudeconvolve2d<double , 1 >(cufinufft_plan_t <double > *d_plan, int blksize);
322+ template int cudeconvolve3d<float , 0 >(cufinufft_plan_t <float > *d_plan, int blksize);
323+ template int cudeconvolve3d<float , 1 >(cufinufft_plan_t <float > *d_plan, int blksize);
324+ template int cudeconvolve3d<double , 0 >(cufinufft_plan_t <double > *d_plan, int blksize);
325+ template int cudeconvolve3d<double , 1 >(cufinufft_plan_t <double > *d_plan, int blksize);
320326
321327} // namespace deconvolve
322328} // namespace cufinufft
0 commit comments