2828
2929from unittest .mock import Mock
3030
31- from httomolibgpu .misc .morph import data_resampler
32-
3331if cupy_run :
3432 from httomolibgpu .cuda_kernels import load_cuda_module
3533 from cupyx .scipy .ndimage import shift , gaussian_filter
@@ -146,29 +144,21 @@ def find_center_vo(
146144 _sino_cs = gaussian_filter (_sino , (3 , 1 ), mode = "reflect" )
147145 _sino_fs = gaussian_filter (_sino , (2 , 2 ), mode = "reflect" )
148146
149- # Downsampling here by averaging along a chosen dimension
150- # NOTE: the gpu implementation of _downsample is erroneuos, needs to be re-written
147+ # Downsampling by averaging along a chosen dimension
148+ if dsp_angle > 1 or dsp_detX > 1 :
149+ _sino_cs = _downsample (_sino_cs , dsp_angle , dsp_detX )
150+
151+ # NOTE: the GPU implementation in _downsample_kernel below is erroneous (different results with each run) and needs to be rewritten
151152 # if dsp_angle > 1:
152- # _sino_cs = _downsample (_sino_cs, level=dsp_angle, axis=0)
153+ # _sino_cs = _downsample_kernel (_sino_cs, level=dsp_angle, axis=0)
153154 # if dsp_detX > 1:
154- # _sino_cs = _downsample (_sino_cs, level=dsp_detX, axis=1)
155+ # _sino_cs = _downsample_kernel (_sino_cs, level=dsp_detX, axis=1)
155156
156- if dsp_angle > 1 or dsp_detX > 1 :
157- # NOTE: this can downsample the data but it is not what is implemented in the original code, so result is slightly different
158- # dsp_angle_size = _sino_cs.shape[0] // dsp_angle
159- # dsp_detX_size = _sino_cs.shape[1] // dsp_detX
160- # _sino_cs = data_resampler(
161- # _sino_cs, [dsp_angle_size, dsp_detX_size], axis=1, interpolation="linear"
162- # )
163- _sino_cs_numpy = _downsample_numpy (
164- _sino_cs .get (), dsp_angle , dsp_detX
165- ) # this is the original CPU implementation
166- _sino_cs = cp .asarray (_sino_cs_numpy , dtype = cp .float32 )
167-
168- # NOTE: this is correct when we do not run any CUDA kernels, hence the performance is suboptimal
157+ # NOTE: this is a correct implementation that avoids running any CUDA kernels, so the performance is suboptimal
169158 init_cen = _search_coarse (_sino_cs , start_cor , stop_cor , ratio , drop )
170159
171- # NOTE: a different to the expected result, not investigated why
160+ # NOTE: as with the coarse search above, this function is currently correct,
161+ # but it does NOT use the existing CUDA kernels. Those kernels need to be rewritten.
172162 fine_cen = _search_fine (
173163 _sino_fs , fine_srange , step , float (init_cen ) * dsp_detX + off_set , ratio , drop
174164 )
@@ -204,7 +194,7 @@ def _search_coarse(sino, smin, smax, ratio, drop):
204194 list_shift = 2.0 * (list_cor - cen_fliplr )
205195 list_metric = cp .empty (list_shift .shape , dtype = cp .float32 )
206196
207- # NOTE: this gives a different result to the CPU code, also works with half data and half mask
197+ # NOTE: this gives a different result to the CPU code, also works with a half data and a half mask
208198 # _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, list_metric)
209199
210200 # This essentially repeats the CPU code... probably not optimal but correct
@@ -234,17 +224,32 @@ def _search_fine(sino, srad, step, init_cen, ratio, drop):
234224
235225 flip_sino = cp .ascontiguousarray (cp .fliplr (sino ))
236226 comp_sino = cp .ascontiguousarray (cp .flipud (sino ))
237- mask = _create_mask (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
227+ mask = _create_mask_numpy (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
228+ mask = cp .asarray (mask , dtype = cp .float32 )
238229
239230 cen_fliplr = (ncol - 1.0 ) / 2.0
240- srad = max (min (abs (float (srad )), ncol / 4.0 ), 1.0 )
241- step = max (min (abs (step ), srad ), 0.1 )
242- init_cen = max (min (init_cen , ncol - srad - 1 ), srad )
243- list_cor = init_cen + cp .arange (- srad , srad + step , step , dtype = np .float32 )
231+ # NOTE: the commented-out bounds below are the previous ones; they differ from the new implementation
232+ # srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
233+ # step = max(min(abs(step), srad), 0.1)
234+ srad = np .clip (np .abs (srad ), 1 , ncol // 10 - 1 )
235+ step = np .clip (np .abs (step ), 0.1 , 1.1 )
236+ init_cen = np .clip (init_cen , srad , ncol - srad - 1 )
237+ list_cor = init_cen + cp .arange (- srad , srad + step , step , dtype = cp .float32 )
244238 list_shift = 2.0 * (list_cor - cen_fliplr )
245239 list_metric = cp .empty (list_shift .shape , dtype = "float32" )
246240
247- _calculate_metric (list_shift , sino , flip_sino , comp_sino , mask , out = list_metric )
241+ for i , shift_l in enumerate (list_shift ):
242+ sino_shift = shift (flip_sino , (0 , shift_l ), order = 3 , prefilter = True )
243+ if shift_l >= 0 :
244+ shift_int = int (cp .ceil (shift_l ))
245+ sino_shift [:, :shift_int ] = comp_sino [:, :shift_int ]
246+ else :
247+ shift_int = int (cp .floor (shift_l ))
248+ sino_shift [:, shift_int :] = comp_sino [:, shift_int :]
249+ mat1 = cp .vstack ((sino , sino_shift ))
250+ list_metric [i ] = cp .mean (cp .abs (fftshift (fft2 (mat1 ))) * mask )
251+
252+ # _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
248253 cor = list_cor [cp .argmin (list_metric )]
249254 return cor
250255
@@ -422,7 +427,7 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
422427 )
423428
424429
425- def _downsample_numpy (image , dsp_fact0 , dsp_fact1 ):
430+ def _downsample (image , dsp_fact0 , dsp_fact1 ):
426431 """Downsample an image by averaging.
427432
428433 Parameters
@@ -436,8 +441,8 @@ def _downsample_numpy(image, dsp_fact0, dsp_fact1):
436441 image_dsp : Downsampled image.
437442 """
438443 (height , width ) = image .shape
439- dsp_fact0 = np .clip (np .int16 (dsp_fact0 ), 1 , height // 2 )
440- dsp_fact1 = np .clip (np .int16 (dsp_fact1 ), 1 , width // 2 )
444+ dsp_fact0 = cp .clip (cp .int16 (dsp_fact0 ), 1 , height // 2 )
445+ dsp_fact1 = cp .clip (cp .int16 (dsp_fact1 ), 1 , width // 2 )
441446 height_dsp = height // dsp_fact0
442447 width_dsp = width // dsp_fact1
443448 if dsp_fact0 == 1 and dsp_fact1 == 1 :
@@ -452,7 +457,7 @@ def _downsample_numpy(image, dsp_fact0, dsp_fact1):
452457 return image_dsp
453458
454459
455- def _downsample (sino , level , axis ):
460+ def _downsample_kernel (sino , level , axis ):
456461 assert sino .dtype == cp .float32 , "single precision floating point input required"
457462 assert sino .flags ["C_CONTIGUOUS" ], "list_shift must be C-contiguous"
458463
0 commit comments