Skip to content

Commit 63623c0

Browse files
committed
more human readable code for cuda kernels in deconvolve_wrapper.cu
1 parent 67fa61e commit 63623c0

File tree

1 file changed

+144
-40
lines changed

1 file changed

+144
-40
lines changed

src/cuda/deconvolve_wrapper.cu

Lines changed: 144 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,21 @@ namespace deconvolve {
1414
// to N/2-1), modeord=1: FFT-compatible mode ordering in fk (from 0 to N/2-1, then -N/2 up to -1).
1515
template <typename T>
1616
__global__ void deconvolve_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord) {
17+
int pivot1, w1, fwkerind1;
18+
T kervalue;
19+
1720
for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ms; i += blockDim.x * gridDim.x) {
18-
int w1 = ( modeord == 0 ) ? ( (i - ms / 2 >= 0) ? i - ms / 2 : nf1 + i - ms / 2 ) : ( (i - ms + ms / 2 >= 0) ? nf1 + i - ms : i );
21+
if (modeord == 0) {
22+
pivot1 = i - ms / 2;
23+
w1 = (pivot1 >= 0) ? pivot1 : nf1 + pivot1;
24+
fwkerind1 = abs(pivot1);
25+
} else {
26+
pivot1 = i - ms + ms / 2;
27+
w1 = (pivot1 >= 0) ? nf1 + i - ms : i;
28+
fwkerind1 = (pivot1 >= 0) ? ms - i : i;
29+
}
1930

20-
T kervalue = fwkerhalf1[(modeord==0) ? abs(i - ms / 2) : ((i - ms + ms / 2 >= 0) ? ms - i : i)];
31+
kervalue = fwkerhalf1[fwkerind1];
2132
fk[i].x = fw[w1].x / kervalue;
2233
fk[i].y = fw[w1].y / kervalue;
2334
}
@@ -26,15 +37,33 @@ __global__ void deconvolve_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex
2637
template <typename T>
2738
__global__ void deconvolve_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1,
2839
T *fwkerhalf2, int modeord) {
40+
int pivot1, pivot2, w1, w2, fwkerind1, fwkerind2;
41+
int k1, k2, inidx, outidx;
42+
T kervalue;
43+
2944
for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ms * mt; i += blockDim.x * gridDim.x) {
30-
int k1 = i % ms;
31-
int k2 = i / ms;
32-
int outidx = k1 + k2 * ms;
33-
int w1 = ( modeord == 0 ) ? ( (k1 - ms / 2 >= 0) ? k1 - ms / 2 : nf1 + k1 - ms / 2 ) : ( (k1 - ms + ms / 2 >= 0) ? nf1 + k1 - ms : k1 );
34-
int w2 = ( modeord == 0 ) ? ( (k2 - mt / 2 >= 0) ? k2 - mt / 2 : nf2 + k2 - mt / 2 ) : ( (k2 - mt + mt / 2 >= 0) ? nf2 + k2 - mt : k2 );
35-
int inidx = w1 + w2 * nf1;
36-
37-
T kervalue = fwkerhalf1[(modeord==0) ? abs(k1 - ms / 2) : ((k1 - ms + ms / 2 >= 0) ? ms - k1 : k1)] * fwkerhalf2[(modeord==0) ? abs(k2 - mt / 2) : ((k2 - mt + mt / 2 >= 0) ? mt - k2 : k2)];
45+
k1 = i % ms;
46+
k2 = i / ms;
47+
outidx = k1 + k2 * ms;
48+
49+
if (modeord == 0) {
50+
pivot1 = k1 - ms / 2;
51+
pivot2 = k2 - mt / 2;
52+
w1 = (pivot1 >= 0) ? pivot1 : nf1 + pivot1;
53+
w2 = (pivot2 >= 0) ? pivot2 : nf2 + pivot2;
54+
fwkerind1 = abs(pivot1);
55+
fwkerind2 = abs(pivot2);
56+
} else {
57+
pivot1 = k1 - ms + ms / 2;
58+
pivot2 = k2 - mt + mt / 2;
59+
w1 = (pivot1 >= 0) ? nf1 + k1 - ms : k1;
60+
w2 = (pivot2 >= 0) ? nf2 + k2 - mt : k2;
61+
fwkerind1 = (pivot1 >= 0) ? ms - k1 : k1;
62+
fwkerind2 = (pivot2 >= 0) ? mt - k2 : k2;
63+
}
64+
65+
inidx = w1 + w2 * nf1;
66+
kervalue = fwkerhalf1[fwkerind1] * fwkerhalf2[fwkerind2];
3867
fk[outidx].x = fw[inidx].x / kervalue;
3968
fk[outidx].y = fw[inidx].y / kervalue;
4069
}
@@ -43,17 +72,40 @@ __global__ void deconvolve_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T>
4372
template <typename T>
4473
__global__ void deconvolve_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, cuda_complex<T> *fw,
4574
cuda_complex<T> *fk, T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord) {
75+
int pivot1, pivot2, pivot3, w1, w2, w3, fwkerind1, fwkerind2, fwkerind3;
76+
int k1, k2, k3, inidx, outidx;
77+
T kervalue;
78+
4679
for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ms * mt * mu; i += blockDim.x * gridDim.x) {
47-
int k1 = i % ms;
48-
int k2 = (i / ms) % mt;
49-
int k3 = (i / ms / mt);
50-
int outidx = k1 + k2 * ms + k3 * ms * mt;
51-
int w1 = ( modeord == 0 ) ? ( (k1 - ms / 2 >= 0) ? k1 - ms / 2 : nf1 + k1 - ms / 2 ) : ( (k1 - ms + ms / 2 >= 0) ? nf1 + k1 - ms : k1 );
52-
int w2 = ( modeord == 0 ) ? ( (k2 - mt / 2 >= 0) ? k2 - mt / 2 : nf2 + k2 - mt / 2 ) : ( (k2 - mt + mt / 2 >= 0) ? nf2 + k2 - mt : k2 );
53-
int w3 = ( modeord == 0 ) ? ( (k3 - mu / 2 >= 0) ? k3 - mu / 2 : nf3 + k3 - mu / 2 ) : ( (k3 - mu + mu / 2 >= 0) ? nf3 + k3 - mu : k3 );
54-
int inidx = w1 + w2 * nf1 + w3 * nf1 * nf2;
55-
56-
T kervalue = fwkerhalf1[(modeord==0) ? abs(k1 - ms / 2) : ((k1 - ms + ms / 2 >= 0) ? ms - k1 : k1)] * fwkerhalf2[(modeord==0) ? abs(k2 - mt / 2) : ((k2 - mt + mt / 2 >= 0) ? mt - k2 : k2)] * fwkerhalf3[(modeord==0) ? abs(k3 - mu / 2) : ((k3 - mu + mu / 2 >= 0) ? mu - k3 : k3)];
80+
k1 = i % ms;
81+
k2 = (i / ms) % mt;
82+
k3 = (i / ms / mt);
83+
outidx = k1 + k2 * ms + k3 * ms * mt;
84+
85+
if (modeord == 0) {
86+
pivot1 = k1 - ms / 2;
87+
pivot2 = k2 - mt / 2;
88+
pivot3 = k3 - mu / 2;
89+
w1 = (pivot1 >= 0) ? pivot1 : nf1 + pivot1;
90+
w2 = (pivot2 >= 0) ? pivot2 : nf2 + pivot2;
91+
w3 = (pivot3 >= 0) ? pivot3 : nf3 + pivot3;
92+
fwkerind1 = abs(pivot1);
93+
fwkerind2 = abs(pivot2);
94+
fwkerind3 = abs(pivot3);
95+
} else {
96+
pivot1 = k1 - ms + ms / 2;
97+
pivot2 = k2 - mt + mt / 2;
98+
pivot3 = k3 - mu + mu / 2;
99+
w1 = (pivot1 >= 0) ? nf1 + k1 - ms : k1;
100+
w2 = (pivot2 >= 0) ? nf2 + k2 - mt : k2;
101+
w3 = (pivot3 >= 0) ? nf3 + k3 - mu : k3;
102+
fwkerind1 = (pivot1 >= 0) ? ms - k1 : k1;
103+
fwkerind2 = (pivot2 >= 0) ? mt - k2 : k2;
104+
fwkerind3 = (pivot3 >= 0) ? mu - k3 : k3;
105+
}
106+
107+
inidx = w1 + w2 * nf1 + w3 * nf1 * nf2;
108+
kervalue = fwkerhalf1[fwkerind1] * fwkerhalf2[fwkerind2] * fwkerhalf3[fwkerind3];
57109
fk[outidx].x = fw[inidx].x / kervalue;
58110
fk[outidx].y = fw[inidx].y / kervalue;
59111
}
@@ -62,10 +114,21 @@ __global__ void deconvolve_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3,
62114
/* Kernel for copying fk to fw with same amplication */
63115
template <typename T>
64116
__global__ void amplify_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1, int modeord) {
117+
int pivot1, w1, fwkerind1;
118+
T kervalue;
119+
65120
for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ms; i += blockDim.x * gridDim.x) {
66-
int w1 = ( modeord == 0 ) ? ( (i - ms / 2 >= 0) ? i - ms / 2 : nf1 + i - ms / 2 ) : ( (i - ms + ms / 2 >= 0) ? nf1 + i - ms : i );
121+
if (modeord == 0) {
122+
pivot1 = i - ms / 2;
123+
w1 = (pivot1 >= 0) ? pivot1 : nf1 + pivot1;
124+
fwkerind1 = abs(pivot1);
125+
} else {
126+
pivot1 = i - ms + ms / 2;
127+
w1 = (pivot1 >= 0) ? nf1 + i - ms : i;
128+
fwkerind1 = (pivot1 >= 0) ? ms - i : i;
129+
}
67130

68-
T kervalue = fwkerhalf1[(modeord==0) ? abs(i - ms / 2) : ((i - ms + ms / 2 >= 0) ? ms - i : i)];
131+
kervalue = fwkerhalf1[fwkerind1];
69132
fw[w1].x = fk[i].x / kervalue;
70133
fw[w1].y = fk[i].y / kervalue;
71134
}
@@ -74,15 +137,33 @@ __global__ void amplify_1d(int ms, int nf1, cuda_complex<T> *fw, cuda_complex<T>
74137
template <typename T>
75138
__global__ void amplify_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw, cuda_complex<T> *fk, T *fwkerhalf1,
76139
T *fwkerhalf2, int modeord) {
140+
int pivot1, pivot2, w1, w2, fwkerind1, fwkerind2;
141+
int k1, k2, inidx, outidx;
142+
T kervalue;
143+
77144
for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ms * mt; i += blockDim.x * gridDim.x) {
78-
int k1 = i % ms;
79-
int k2 = i / ms;
80-
int inidx = k1 + k2 * ms;
81-
int w1 = ( modeord == 0 ) ? ( (k1 - ms / 2 >= 0) ? k1 - ms / 2 : nf1 + k1 - ms / 2 ) : ( (k1 - ms + ms / 2 >= 0) ? nf1 + k1 - ms : k1 );
82-
int w2 = ( modeord == 0 ) ? ( (k2 - mt / 2 >= 0) ? k2 - mt / 2 : nf2 + k2 - mt / 2 ) : ( (k2 - mt + mt / 2 >= 0) ? nf2 + k2 - mt : k2 );
83-
int outidx = w1 + w2 * nf1;
84-
85-
T kervalue = fwkerhalf1[(modeord==0) ? abs(k1 - ms / 2) : ((k1 - ms + ms / 2 >= 0) ? ms - k1 : k1)] * fwkerhalf2[(modeord==0) ? abs(k2 - mt / 2) : ((k2 - mt + mt / 2 >= 0) ? mt - k2 : k2)];
145+
k1 = i % ms;
146+
k2 = i / ms;
147+
inidx = k1 + k2 * ms;
148+
149+
if (modeord == 0) {
150+
pivot1 = k1 - ms / 2;
151+
pivot2 = k2 - mt / 2;
152+
w1 = (pivot1 >= 0) ? pivot1 : nf1 + pivot1;
153+
w2 = (pivot2 >= 0) ? pivot2 : nf2 + pivot2;
154+
fwkerind1 = abs(pivot1);
155+
fwkerind2 = abs(pivot2);
156+
} else {
157+
pivot1 = k1 - ms + ms / 2;
158+
pivot2 = k2 - mt + mt / 2;
159+
w1 = (pivot1 >= 0) ? nf1 + k1 - ms : k1;
160+
w2 = (pivot2 >= 0) ? nf2 + k2 - mt : k2;
161+
fwkerind1 = (pivot1 >= 0) ? ms - k1 : k1;
162+
fwkerind2 = (pivot2 >= 0) ? mt - k2 : k2;
163+
}
164+
165+
outidx = w1 + w2 * nf1;
166+
kervalue = fwkerhalf1[fwkerind1] * fwkerhalf2[fwkerind2];
86167
fw[outidx].x = fk[inidx].x / kervalue;
87168
fw[outidx].y = fk[inidx].y / kervalue;
88169
}
@@ -91,17 +172,40 @@ __global__ void amplify_2d(int ms, int mt, int nf1, int nf2, cuda_complex<T> *fw
91172
template <typename T>
92173
__global__ void amplify_3d(int ms, int mt, int mu, int nf1, int nf2, int nf3, cuda_complex<T> *fw, cuda_complex<T> *fk,
93174
T *fwkerhalf1, T *fwkerhalf2, T *fwkerhalf3, int modeord) {
175+
int pivot1, pivot2, pivot3, w1, w2, w3, fwkerind1, fwkerind2, fwkerind3;
176+
int k1, k2, k3, inidx, outidx;
177+
T kervalue;
178+
94179
for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < ms * mt * mu; i += blockDim.x * gridDim.x) {
95-
int k1 = i % ms;
96-
int k2 = (i / ms) % mt;
97-
int k3 = (i / ms / mt);
98-
int inidx = k1 + k2 * ms + k3 * ms * mt;
99-
int w1 = ( modeord == 0 ) ? ( (k1 - ms / 2 >= 0) ? k1 - ms / 2 : nf1 + k1 - ms / 2 ) : ( (k1 - ms + ms / 2 >= 0) ? nf1 + k1 - ms : k1 );
100-
int w2 = ( modeord == 0 ) ? ( (k2 - mt / 2 >= 0) ? k2 - mt / 2 : nf2 + k2 - mt / 2 ) : ( (k2 - mt + mt / 2 >= 0) ? nf2 + k2 - mt : k2 );
101-
int w3 = ( modeord == 0 ) ? ( (k3 - mu / 2 >= 0) ? k3 - mu / 2 : nf3 + k3 - mu / 2 ) : ( (k3 - mu + mu / 2 >= 0) ? nf3 + k3 - mu : k3 );
102-
int outidx = w1 + w2 * nf1 + w3 * nf1 * nf2;
103-
104-
T kervalue = fwkerhalf1[(modeord==0) ? abs(k1 - ms / 2) : ((k1 - ms + ms / 2 >= 0) ? ms - k1 : k1)] * fwkerhalf2[(modeord==0) ? abs(k2 - mt / 2) : ((k2 - mt + mt / 2 >= 0) ? mt - k2 : k2)] * fwkerhalf3[(modeord==0) ? abs(k3 - mu / 2) : ((k3 - mu + mu / 2 >= 0) ? mu - k3 : k3)];
180+
k1 = i % ms;
181+
k2 = (i / ms) % mt;
182+
k3 = (i / ms / mt);
183+
inidx = k1 + k2 * ms + k3 * ms * mt;
184+
185+
if (modeord == 0) {
186+
pivot1 = k1 - ms / 2;
187+
pivot2 = k2 - mt / 2;
188+
pivot3 = k3 - mu / 2;
189+
w1 = (pivot1 >= 0) ? pivot1 : nf1 + pivot1;
190+
w2 = (pivot2 >= 0) ? pivot2 : nf2 + pivot2;
191+
w3 = (pivot3 >= 0) ? pivot3 : nf3 + pivot3;
192+
fwkerind1 = abs(pivot1);
193+
fwkerind2 = abs(pivot2);
194+
fwkerind3 = abs(pivot3);
195+
} else {
196+
pivot1 = k1 - ms + ms / 2;
197+
pivot2 = k2 - mt + mt / 2;
198+
pivot3 = k3 - mu + mu / 2;
199+
w1 = (pivot1 >= 0) ? nf1 + k1 - ms : k1;
200+
w2 = (pivot2 >= 0) ? nf2 + k2 - mt : k2;
201+
w3 = (pivot3 >= 0) ? nf3 + k3 - mu : k3;
202+
fwkerind1 = (pivot1 >= 0) ? ms - k1 : k1;
203+
fwkerind2 = (pivot2 >= 0) ? mt - k2 : k2;
204+
fwkerind3 = (pivot3 >= 0) ? mu - k3 : k3;
205+
}
206+
207+
outidx = w1 + w2 * nf1 + w3 * nf1 * nf2;
208+
kervalue = fwkerhalf1[fwkerind1] * fwkerhalf2[fwkerind2] * fwkerhalf3[fwkerind3];
105209
fw[outidx].x = fk[inidx].x / kervalue;
106210
fw[outidx].y = fk[inidx].y / kervalue;
107211
}

0 commit comments

Comments
 (0)