Skip to content

Commit 0338011

Browse files
committed
VoCentering GPU suboptimal but working code
1 parent d50f566 commit 0338011

File tree

1 file changed

+36
-31
lines changed

1 file changed

+36
-31
lines changed

httomolibgpu/recon/rotation.py

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828

2929
from unittest.mock import Mock
3030

31-
from httomolibgpu.misc.morph import data_resampler
32-
3331
if cupy_run:
3432
from httomolibgpu.cuda_kernels import load_cuda_module
3533
from cupyx.scipy.ndimage import shift, gaussian_filter
@@ -146,29 +144,21 @@ def find_center_vo(
146144
_sino_cs = gaussian_filter(_sino, (3, 1), mode="reflect")
147145
_sino_fs = gaussian_filter(_sino, (2, 2), mode="reflect")
148146

149-
# Downsampling here by averaging along a chosen dimension
150-
# NOTE: the gpu implementation of _downsample is erroneuos, needs to be re-written
147+
# Downsampling by averaging along a chosen dimension
148+
if dsp_angle > 1 or dsp_detX > 1:
149+
_sino_cs = _downsample(_sino_cs, dsp_angle, dsp_detX)
150+
151+
# NOTE: the GPU implementation of the _downsample_kernel below is erroneous (it gives different results on each run) and needs to be re-written
151152
# if dsp_angle > 1:
152-
# _sino_cs = _downsample(_sino_cs, level=dsp_angle, axis=0)
153+
# _sino_cs = _downsample_kernel(_sino_cs, level=dsp_angle, axis=0)
153154
# if dsp_detX > 1:
154-
# _sino_cs = _downsample(_sino_cs, level=dsp_detX, axis=1)
155+
# _sino_cs = _downsample_kernel(_sino_cs, level=dsp_detX, axis=1)
155156

156-
if dsp_angle > 1 or dsp_detX > 1:
157-
# NOTE: this can downsample the data but it is not what is implemented in the original code, so result is slightly different
158-
# dsp_angle_size = _sino_cs.shape[0] // dsp_angle
159-
# dsp_detX_size = _sino_cs.shape[1] // dsp_detX
160-
# _sino_cs = data_resampler(
161-
# _sino_cs, [dsp_angle_size, dsp_detX_size], axis=1, interpolation="linear"
162-
# )
163-
_sino_cs_numpy = _downsample_numpy(
164-
_sino_cs.get(), dsp_angle, dsp_detX
165-
) # this is the original CPU implementation
166-
_sino_cs = cp.asarray(_sino_cs_numpy, dtype=cp.float32)
167-
168-
# NOTE: this is correct when we do not run any CUDA kernels, hence the performance is suboptimal
157+
# NOTE: this is a correct implementation that avoids running any CUDA kernels, so the performance is suboptimal
169158
init_cen = _search_coarse(_sino_cs, start_cor, stop_cor, ratio, drop)
170159

171-
# NOTE: a different to the expected result, not investigated why
160+
# NOTE: similar to the coarse module above, this is currently a correct function
161+
# but it does NOT use the CUDA kernels that were written, so those kernels still need re-writing.
172162
fine_cen = _search_fine(
173163
_sino_fs, fine_srange, step, float(init_cen) * dsp_detX + off_set, ratio, drop
174164
)
@@ -204,7 +194,7 @@ def _search_coarse(sino, smin, smax, ratio, drop):
204194
list_shift = 2.0 * (list_cor - cen_fliplr)
205195
list_metric = cp.empty(list_shift.shape, dtype=cp.float32)
206196

207-
# NOTE: this gives a different result to the CPU code, also works with half data and half mask
197+
# NOTE: this gives a different result to the CPU code, also works with a half data and a half mask
208198
# _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, list_metric)
209199

210200
# This essentially repeats the CPU code... probably not optimal but correct
@@ -234,17 +224,32 @@ def _search_fine(sino, srad, step, init_cen, ratio, drop):
234224

235225
flip_sino = cp.ascontiguousarray(cp.fliplr(sino))
236226
comp_sino = cp.ascontiguousarray(cp.flipud(sino))
237-
mask = _create_mask(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
227+
mask = _create_mask_numpy(2 * nrow, ncol, 0.5 * ratio * ncol, drop)
228+
mask = cp.asarray(mask, dtype=cp.float32)
238229

239230
cen_fliplr = (ncol - 1.0) / 2.0
240-
srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
241-
step = max(min(abs(step), srad), 0.1)
242-
init_cen = max(min(init_cen, ncol - srad - 1), srad)
243-
list_cor = init_cen + cp.arange(-srad, srad + step, step, dtype=np.float32)
231+
# NOTE: the commented-out srad/step clamping below differs from the new implementation
232+
# srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
233+
# step = max(min(abs(step), srad), 0.1)
234+
srad = np.clip(np.abs(srad), 1, ncol // 10 - 1)
235+
step = np.clip(np.abs(step), 0.1, 1.1)
236+
init_cen = np.clip(init_cen, srad, ncol - srad - 1)
237+
list_cor = init_cen + cp.arange(-srad, srad + step, step, dtype=cp.float32)
244238
list_shift = 2.0 * (list_cor - cen_fliplr)
245239
list_metric = cp.empty(list_shift.shape, dtype="float32")
246240

247-
_calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
241+
for i, shift_l in enumerate(list_shift):
242+
sino_shift = shift(flip_sino, (0, shift_l), order=3, prefilter=True)
243+
if shift_l >= 0:
244+
shift_int = int(cp.ceil(shift_l))
245+
sino_shift[:, :shift_int] = comp_sino[:, :shift_int]
246+
else:
247+
shift_int = int(cp.floor(shift_l))
248+
sino_shift[:, shift_int:] = comp_sino[:, shift_int:]
249+
mat1 = cp.vstack((sino, sino_shift))
250+
list_metric[i] = cp.mean(cp.abs(fftshift(fft2(mat1))) * mask)
251+
252+
# _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
248253
cor = list_cor[cp.argmin(list_metric)]
249254
return cor
250255

@@ -422,7 +427,7 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
422427
)
423428

424429

425-
def _downsample_numpy(image, dsp_fact0, dsp_fact1):
430+
def _downsample(image, dsp_fact0, dsp_fact1):
426431
"""Downsample an image by averaging.
427432
428433
Parameters
@@ -436,8 +441,8 @@ def _downsample_numpy(image, dsp_fact0, dsp_fact1):
436441
image_dsp : Downsampled image.
437442
"""
438443
(height, width) = image.shape
439-
dsp_fact0 = np.clip(np.int16(dsp_fact0), 1, height // 2)
440-
dsp_fact1 = np.clip(np.int16(dsp_fact1), 1, width // 2)
444+
dsp_fact0 = cp.clip(cp.int16(dsp_fact0), 1, height // 2)
445+
dsp_fact1 = cp.clip(cp.int16(dsp_fact1), 1, width // 2)
441446
height_dsp = height // dsp_fact0
442447
width_dsp = width // dsp_fact1
443448
if dsp_fact0 == 1 and dsp_fact1 == 1:
@@ -452,7 +457,7 @@ def _downsample_numpy(image, dsp_fact0, dsp_fact1):
452457
return image_dsp
453458

454459

455-
def _downsample(sino, level, axis):
460+
def _downsample_kernel(sino, level, axis):
456461
assert sino.dtype == cp.float32, "single precision floating point input required"
457462
assert sino.flags["C_CONTIGUOUS"], "list_shift must be C-contiguous"
458463

0 commit comments

Comments
 (0)