Skip to content

Commit bb78e88

Browse files
committed
templating modeord in cuda deconv kernels
1 parent 63623c0 commit bb78e88

File tree

5 files changed

+96
-60
lines changed

5 files changed

+96
-60
lines changed

include/cufinufft/cudeconvolve.h

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,29 @@
55

66
namespace cufinufft {
77
namespace deconvolve {
8-
template <typename T>
9-
__global__ void deconvolve_1d(int ms, int nf1, int fw_width, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord);
10-
template <typename T>
11-
__global__ void amplify_1d(int ms, int nf1, int fw_width, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf2, int modeord);
12-
template <typename T>
8+
template <typename T, int modeord>
9+
__global__ void deconvolve_1d(int ms, int nf1, int fw_width, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1);
10+
template <typename T, int modeord>
11+
__global__ void amplify_1d(int ms, int nf1, int fw_width, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf2);
12+
template <typename T, int modeord>
1313
__global__ void deconvolve_2d(int ms, int mt, int nf1, int nf2, int fw_width, cuda_complex<T> *fw, cuda_complex<T> *fk,
14-
T *fwkerhalf1, T *fwkerhalf2, int modeord);
15-
template <typename T>
14+
T *fwkerhalf1, T *fwkerhalf2);
15+
template <typename T, int modeord>
1616
__global__ void amplify_2d(int ms, int mt, int nf1, int nf2, int fw_width, cuda_complex<T> *fw, cuda_complex<T> *fk,
17-
T *fwkerhalf1, T *fwkerhalf2, int modeord);
17+
T *fwkerhalf1, T *fwkerhalf2);
1818

19-
template <typename T>
19+
template <typename T, int modeord>
2020
__global__ void deconvolve_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, int fw_width, cuda_complex<T> *fw,
21-
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord);
22-
template <typename T>
21+
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3);
22+
template <typename T, int modeord>
2323
__global__ void amplify_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, int fw_width, cuda_complex<T> *fw,
24-
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord);
24+
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3);
2525

26-
template <typename T>
26+
template <typename T, int modeord>
2727
int cudeconvolve1d(cufinufft_plan_t<T> *d_mem, int blksize);
28-
template <typename T>
28+
template <typename T, int modeord>
2929
int cudeconvolve2d(cufinufft_plan_t<T> *d_mem, int blksize);
30-
template <typename T>
30+
template <typename T, int modeord>
3131
int cudeconvolve3d(cufinufft_plan_t<T> *d_mem, int blksize);
3232
} // namespace deconvolve
3333
} // namespace cufinufft

src/cuda/1d/cufinufft1d.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,13 @@ int cufinufft1d1_exec(cuda_complex<T> *d_c, cuda_complex<T> *d_fk, cufinufft_pla
5959
return FINUFFT_ERR_CUDA_FAILURE;
6060

6161
// Step 3: deconvolve and shuffle
62-
if ((ier = cudeconvolve1d<T>(d_plan, blksize)))
63-
return ier;
62+
if (d_plan->opts.modeord == 0) {
63+
if ((ier = cudeconvolve1d<T, 0>(d_plan, blksize)))
64+
return ier;
65+
} else {
66+
if ((ier = cudeconvolve1d<T, 1>(d_plan, blksize)))
67+
return ier;
68+
}
6469
}
6570

6671
return 0;
@@ -95,8 +100,13 @@ int cufinufft1d2_exec(cuda_complex<T> *d_c, cuda_complex<T> *d_fk, cufinufft_pla
95100
d_plan->fk = d_fkstart;
96101

97102
// Step 1: amplify Fourier coeffs fk and copy into upsampled array fw
98-
if ((ier = cudeconvolve1d<T>(d_plan, blksize)))
99-
return ier;
103+
if (d_plan->opts.modeord == 0) {
104+
if ((ier = cudeconvolve1d<T, 0>(d_plan, blksize)))
105+
return ier;
106+
} else {
107+
if ((ier = cudeconvolve1d<T, 1>(d_plan, blksize)))
108+
return ier;
109+
}
100110

101111
// Step 2: FFT
102112
cufftResult cufft_status = cufft_ex(d_plan->fftplan, d_plan->fw, d_plan->fw, d_plan->iflag);

src/cuda/2d/cufinufft2d.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,13 @@ int cufinufft2d1_exec(cuda_complex<T> *d_c, cuda_complex<T> *d_fk, cufinufft_pla
5959
return FINUFFT_ERR_CUDA_FAILURE;
6060

6161
// Step 3: deconvolve and shuffle
62-
if ((ier = cudeconvolve2d<T>(d_plan, blksize)))
63-
return ier;
62+
if (d_plan->opts.modeord == 0) {
63+
if ((ier = cudeconvolve2d<T, 0>(d_plan, blksize)))
64+
return ier;
65+
} else {
66+
if ((ier = cudeconvolve2d<T, 1>(d_plan, blksize)))
67+
return ier;
68+
}
6469
}
6570

6671
return 0;
@@ -95,8 +100,13 @@ int cufinufft2d2_exec(cuda_complex<T> *d_c, cuda_complex<T> *d_fk, cufinufft_pla
95100
d_plan->fk = d_fkstart;
96101

97102
// Step 1: amplify Fourier coeffs fk and copy into upsampled array fw
98-
if ((ier = cudeconvolve2d<T>(d_plan, blksize)))
99-
return ier;
103+
if (d_plan->opts.modeord == 0) {
104+
if ((ier = cudeconvolve2d<T, 0>(d_plan, blksize)))
105+
return ier;
106+
} else {
107+
if ((ier = cudeconvolve2d<T, 1>(d_plan, blksize)))
108+
return ier;
109+
}
100110

101111
// Step 2: FFT
102112
cufftResult cufft_status = cufft_ex(d_plan->fftplan, d_plan->fw, d_plan->fw, d_plan->iflag);

src/cuda/3d/cufinufft3d.cu

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,13 @@ int cufinufft3d1_exec(cuda_complex<T> *d_c, cuda_complex<T> *d_fk, cufinufft_pla
5757
return FINUFFT_ERR_CUDA_FAILURE;
5858

5959
// Step 3: deconvolve and shuffle
60-
if ((ier = cudeconvolve3d<T>(d_plan, blksize)))
61-
return ier;
60+
if (d_plan->opts.modeord == 0) {
61+
if ((ier = cudeconvolve3d<T, 0>(d_plan, blksize)))
62+
return ier;
63+
} else {
64+
if ((ier = cudeconvolve3d<T, 1>(d_plan, blksize)))
65+
return ier;
66+
}
6267
}
6368

6469
return 0;
@@ -91,8 +96,13 @@ int cufinufft3d2_exec(cuda_complex<T> *d_c, cuda_complex<T> *d_fk, cufinufft_pla
9196
d_plan->fk = d_fkstart;
9297

9398
// Step 1: amplify Fourier coeffs fk and copy into upsampled array fw
94-
if ((ier = cudeconvolve3d<T>(d_plan, blksize)))
95-
return ier;
99+
if (d_plan->opts.modeord == 0) {
100+
if ((ier = cudeconvolve3d<T, 0>(d_plan, blksize)))
101+
return ier;
102+
} else {
103+
if ((ier = cudeconvolve3d<T, 1>(d_plan, blksize)))
104+
return ier;
105+
}
96106

97107
// Step 2: FFT
98108
RETURN_IF_CUDA_ERROR

src/cuda/deconvolve_wrapper.cu

Lines changed: 39 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ namespace deconvolve {
1212
/* Kernel for copying fw to fk with amplification by prefac/ker */
1313
// Note: assume modeord=0: CMCL-compatible mode ordering in fk (from -N/2 up
1414
// to N/2-1), modeord=1: FFT-compatible mode ordering in fk (from 0 to N/2-1, then -N/2 up to -1).
15-
template <typename T>
16-
__global__ void deconvolve_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord) {
15+
template <typename T, int modeord>
16+
__global__ void deconvolve_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1) {
1717
int pivot1, w1, fwkerind1;
1818
T kervalue;
1919

@@ -34,9 +34,9 @@ __global__ void deconvolve_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex
3434
}
3535
}
3636

37-
template <typename T>
37+
template <typename T, int modeord>
3838
__global__ void deconvolve_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1,
39-
T *fwkerhalf2, int modeord) {
39+
T *fwkerhalf2) {
4040
int pivot1, pivot2, w1, w2, fwkerind1, fwkerind2;
4141
int k1, k2, inidx, outidx;
4242
T kervalue;
@@ -69,9 +69,9 @@ __global__ void deconvolve_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T>
6969
}
7070
}
7171

72-
template <typename T>
72+
template <typename T, int modeord>
7373
__global__ void deconvolve_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, cuda_complex<T> *fw,
74-
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord) {
74+
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3) {
7575
int pivot1, pivot2, pivot3, w1, w2, w3, fwkerind1, fwkerind2, fwkerind3;
7676
int k1, k2, k3, inidx, outidx;
7777
T kervalue;
@@ -112,8 +112,8 @@ __global__ void deconvolve_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3,
112112
}
113113

114114
/* Kernel for copying fk to fw with same amplification */
115-
template <typename T>
116-
__global__ void amplify_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord) {
115+
template <typename T, int modeord>
116+
__global__ void amplify_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1) {
117117
int pivot1, w1, fwkerind1;
118118
T kervalue;
119119

@@ -134,9 +134,9 @@ __global__ void amplify_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T>
134134
}
135135
}
136136

137-
template <typename T>
137+
template <typename T, int modeord>
138138
__global__ void amplify_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1,
139-
T *fwkerhalf2, int modeord) {
139+
T *fwkerhalf2) {
140140
int pivot1, pivot2, w1, w2, fwkerind1, fwkerind2;
141141
int k1, k2, inidx, outidx;
142142
T kervalue;
@@ -169,9 +169,9 @@ __global__ void amplify_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw
169169
}
170170
}
171171

172-
template <typename T>
172+
template <typename T, int modeord>
173173
__global__ void amplify_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, cuda_complex<T> *fw, cuda_complex<T> *fk,
174-
T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord) {
174+
T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3) {
175175
int pivot1, pivot2, pivot3, w1, w2, w3, fwkerind1, fwkerind2, fwkerind3;
176176
int k1, k2, k3, inidx, outidx;
177177
T kervalue;
@@ -211,7 +211,7 @@ __global__ void amplify_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, cu
211211
}
212212
}
213213

214-
template <typename T>
214+
template <typename T, int modeord>
215215
int cudeconvolve1d(cufinufft_plan_t<T> *d_plan, int blksize)
216216
/*
217217
wrapper for deconvolution & amplification in 1D.
@@ -228,20 +228,20 @@ int cudeconvolve1d(cufinufft_plan_t<T> *d_plan, int blksize)
228228

229229
if (d_plan->spopts.spread_direction == 1) {
230230
for (int t = 0; t < blksize; t++) {
231-
deconvolve_1d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, nf1, d_plan->fw + t * nf1,
232-
d_plan->fk + t * nmodes, d_plan->fwkerhalf1, d_plan->opts.modeord);
231+
deconvolve_1d<T, modeord><<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, nf1, d_plan->fw + t * nf1,
232+
d_plan->fk + t * nmodes, d_plan->fwkerhalf1);
233233
}
234234
} else {
235235
checkCudaErrors(cudaMemsetAsync(d_plan->fw, 0, maxbatchsize * nf1 * sizeof(cuda_complex<T>), stream));
236236
for (int t = 0; t < blksize; t++) {
237-
amplify_1d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, nf1, d_plan->fw + t * nf1,
238-
d_plan->fk + t * nmodes, d_plan->fwkerhalf1, d_plan->opts.modeord);
237+
amplify_1d<T, modeord><<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, nf1, d_plan->fw + t * nf1,
238+
d_plan->fk + t * nmodes, d_plan->fwkerhalf1);
239239
}
240240
}
241241
return 0;
242242
}
243243

244-
template <typename T>
244+
template <typename T, int modeord>
245245
int cudeconvolve2d(cufinufft_plan_t<T> *d_plan, int blksize)
246246
/*
247247
wrapper for deconvolution & amplification in 2D.
@@ -260,22 +260,22 @@ int cudeconvolve2d(cufinufft_plan_t<T> *d_plan, int blksize)
260260

261261
if (d_plan->spopts.spread_direction == 1) {
262262
for (int t = 0; t < blksize; t++) {
263-
deconvolve_2d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
263+
deconvolve_2d<T, modeord><<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
264264
d_plan->fk + t * nmodes, d_plan->fwkerhalf1,
265-
d_plan->fwkerhalf2, d_plan->opts.modeord);
265+
d_plan->fwkerhalf2);
266266
}
267267
} else {
268268
checkCudaErrors(cudaMemsetAsync(d_plan->fw, 0, maxbatchsize * nf1 * nf2 * sizeof(cuda_complex<T>), stream));
269269
for (int t = 0; t < blksize; t++) {
270-
amplify_2d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
270+
amplify_2d<T, modeord><<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(ms, mt, nf1, nf2, d_plan->fw + t * nf1 * nf2,
271271
d_plan->fk + t * nmodes, d_plan->fwkerhalf1,
272-
d_plan->fwkerhalf2, d_plan->opts.modeord);
272+
d_plan->fwkerhalf2);
273273
}
274274
}
275275
return 0;
276276
}
277277

278-
template <typename T>
278+
template <typename T, int modeord>
279279
int cudeconvolve3d(cufinufft_plan_t<T> *d_plan, int blksize)
280280
/*
281281
wrapper for deconvolution & amplification in 3D.
@@ -295,28 +295,34 @@ int cudeconvolve3d(cufinufft_plan_t<T> *d_plan, int blksize)
295295
int maxbatchsize = d_plan->maxbatchsize;
296296
if (d_plan->spopts.spread_direction == 1) {
297297
for (int t = 0; t < blksize; t++) {
298-
deconvolve_3d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(
298+
deconvolve_3d<T, modeord><<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(
299299
ms, mt, mu, nf1, nf2, nf3, d_plan->fw + t * nf1 * nf2 * nf3, d_plan->fk + t * nmodes,
300-
d_plan->fwkerhalf1, d_plan->fwkerhalf2, d_plan->fwkerhalf3, d_plan->opts.modeord);
300+
d_plan->fwkerhalf1, d_plan->fwkerhalf2, d_plan->fwkerhalf3);
301301
}
302302
} else {
303303
checkCudaErrors(
304304
cudaMemsetAsync(d_plan->fw, 0, maxbatchsize * nf1 * nf2 * nf3 * sizeof(cuda_complex<T>), stream));
305305
for (int t = 0; t < blksize; t++) {
306-
amplify_3d<<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(
306+
amplify_3d<T, modeord><<<(nmodes + 256 - 1) / 256, 256, 0, stream>>>(
307307
ms, mt, mu, nf1, nf2, nf3, d_plan->fw + t * nf1 * nf2 * nf3, d_plan->fk + t * nmodes,
308-
d_plan->fwkerhalf1, d_plan->fwkerhalf2, d_plan->fwkerhalf3, d_plan->opts.modeord);
308+
d_plan->fwkerhalf1, d_plan->fwkerhalf2, d_plan->fwkerhalf3);
309309
}
310310
}
311311
return 0;
312312
}
313313

314-
template int cudeconvolve1d<float>(cufinufft_plan_t<float> *d_plan, int blksize);
315-
template int cudeconvolve1d<double>(cufinufft_plan_t<double> *d_plan, int blksize);
316-
template int cudeconvolve2d<float>(cufinufft_plan_t<float> *d_plan, int blksize);
317-
template int cudeconvolve2d<double>(cufinufft_plan_t<double> *d_plan, int blksize);
318-
template int cudeconvolve3d<float>(cufinufft_plan_t<float> *d_plan, int blksize);
319-
template int cudeconvolve3d<double>(cufinufft_plan_t<double> *d_plan, int blksize);
314+
template int cudeconvolve1d<float, 0>(cufinufft_plan_t<float> *d_plan, int blksize);
315+
template int cudeconvolve1d<float, 1>(cufinufft_plan_t<float> *d_plan, int blksize);
316+
template int cudeconvolve1d<double, 0>(cufinufft_plan_t<double> *d_plan, int blksize);
317+
template int cudeconvolve1d<double, 1>(cufinufft_plan_t<double> *d_plan, int blksize);
318+
template int cudeconvolve2d<float, 0>(cufinufft_plan_t<float> *d_plan, int blksize);
319+
template int cudeconvolve2d<float, 1>(cufinufft_plan_t<float> *d_plan, int blksize);
320+
template int cudeconvolve2d<double, 0>(cufinufft_plan_t<double> *d_plan, int blksize);
321+
template int cudeconvolve2d<double, 1>(cufinufft_plan_t<double> *d_plan, int blksize);
322+
template int cudeconvolve3d<float, 0>(cufinufft_plan_t<float> *d_plan, int blksize);
323+
template int cudeconvolve3d<float, 1>(cufinufft_plan_t<float> *d_plan, int blksize);
324+
template int cudeconvolve3d<double, 0>(cufinufft_plan_t<double> *d_plan, int blksize);
325+
template int cudeconvolve3d<double, 1>(cufinufft_plan_t<double> *d_plan, int blksize);
320326

321327
} // namespace deconvolve
322328
} // namespace cufinufft

0 commit comments

Comments
 (0)