@@ -229,31 +229,16 @@ def _search_fine(sino, srad, step, init_cen, ratio, drop):
229229 flip_sino = cp .ascontiguousarray (cp .fliplr (sino ))
230230 comp_sino = cp .ascontiguousarray (cp .flipud (sino ))
231231 mask = _create_mask (2 * nrow , ncol , 0.5 * ratio * ncol , drop )
232- # mask = cp.asarray(mask, dtype=cp.float32)
233232
234233 cen_fliplr = (ncol - 1.0 ) / 2.0
235- # NOTE: those are different to new implementation
236- # srad = max(min(abs(float(srad)), ncol / 4.0), 1.0)
237- # step = max(min(abs(step), srad), 0.1)
238234 srad = np .clip (np .abs (srad ), 1 , ncol // 10 - 1 )
239235 step = np .clip (np .abs (step ), 0.1 , 1.1 )
240236 init_cen = np .clip (init_cen , srad , ncol - srad - 1 )
241237 list_cor = init_cen + cp .arange (- srad , srad + step , step , dtype = cp .float32 )
242238 list_shift = 2.0 * (list_cor - cen_fliplr )
243239 list_metric = cp .empty (list_shift .shape , dtype = "float32" )
244240
245- for i , shift_l in enumerate (list_shift ):
246- sino_shift = shift (flip_sino , (0 , shift_l ), order = 3 , prefilter = True )
247- if shift_l >= 0 :
248- shift_int = int (cp .ceil (shift_l ))
249- sino_shift [:, :shift_int ] = comp_sino [:, :shift_int ]
250- else :
251- shift_int = int (cp .floor (shift_l ))
252- sino_shift [:, shift_int :] = comp_sino [:, shift_int :]
253- mat1 = cp .vstack ((sino , sino_shift ))
254- list_metric [i ] = cp .mean (cp .abs (fftshift (fft2 (mat1 ))) * mask )
255-
256- # _calculate_metric(list_shift, sino, flip_sino, comp_sino, mask, out=list_metric)
241+ _calculate_metric (list_shift , sino , flip_sino , comp_sino , mask , out = list_metric )
257242 cor = list_cor [cp .argmin (list_metric )]
258243 return cor
259244
@@ -273,6 +258,35 @@ def _create_mask_numpy(nrow, ncol, radius, drop):
273258 mask [:, cen_col - 1 : cen_col + 2 ] = 0.0
274259 return mask
275260
261+ def _create_mask_half (nrow , ncol , radius , drop ):
262+ du = 1.0 / ncol
263+ dv = (nrow - 1.0 ) / (nrow * 2.0 * np .pi )
264+ cen_row = int (math .ceil (nrow / 2.0 ) - 1 )
265+ cen_col = int (math .ceil (ncol / 2.0 ) - 1 )
266+ drop = min ([drop , int (math .ceil (0.05 * nrow ))])
267+
268+ block_x = 128
269+ block_y = 1
270+ block_dims = (block_x , block_y )
271+ grid_x = (ncol // 2 + 1 + block_x - 1 ) // block_x
272+ grid_y = nrow
273+ grid_dims = (grid_x , grid_y )
274+ mask = cp .empty ((nrow , ncol // 2 + 1 ), dtype = "uint16" )
275+ params = (
276+ ncol ,
277+ nrow ,
278+ cen_col ,
279+ cen_row ,
280+ cp .float32 (du ),
281+ cp .float32 (dv ),
282+ cp .float32 (radius ),
283+ cp .float32 (drop ),
284+ mask ,
285+ )
286+ module = load_cuda_module ("generate_mask" )
287+ kernel = module .get_function ("generate_mask" )
288+ kernel (grid_dims , block_dims , params )
289+ return mask
276290
277291def _create_mask (nrow , ncol , radius , drop ):
278292 du = 1.0 / ncol
@@ -300,7 +314,7 @@ def _create_mask(nrow, ncol, radius, drop):
300314 mask ,
301315 )
302316 module = load_cuda_module ("generate_mask" )
303- kernel = module .get_function ("generate_mask_new " )
317+ kernel = module .get_function ("generate_mask_full " )
304318 kernel (grid_dims , block_dims , params )
305319 return mask
306320
@@ -344,28 +358,30 @@ def _calculate_chunks(
344358 return stop_idx
345359
346360
347- def _calculate_metric (list_shift , sino1 , sino2 , sino3 , mask , out ):
361+ def _calculate_metric (list_shift , sino , flip_sino , comp_sino , mask , out ):
348362 # this tries to simplify - if shift_col is integer, no need to spline interpolate
349363 assert list_shift .dtype == cp .float32 , "shifts must be single precision floats"
350- assert sino1 .dtype == cp .float32 , "sino1 must be float32"
351- assert sino2 .dtype == cp .float32 , "sino1 must be float32"
352- assert sino3 .dtype == cp .float32 , "sino1 must be float32"
353- assert out .dtype == cp .float32 , "sino1 must be float32"
354- assert sino2 .flags ["C_CONTIGUOUS" ], "sino2 must be C-contiguous"
355- assert sino3 .flags ["C_CONTIGUOUS" ], "sino3 must be C-contiguous"
364+ assert sino .dtype == cp .float32 , "sino must be float32"
365+ assert flip_sino .dtype == cp .float32 , "flip_sino must be float32"
366+ assert comp_sino .dtype == cp .float32 , "comp_sino must be float32"
367+ assert out .dtype == cp .float32 , "out must be float32"
368+ assert flip_sino .flags ["C_CONTIGUOUS" ], "flip_sino must be C-contiguous"
369+ assert comp_sino .flags ["C_CONTIGUOUS" ], "comp_sino must be C-contiguous"
356370 assert list_shift .flags ["C_CONTIGUOUS" ], "list_shift must be C-contiguous"
357371 nshifts = list_shift .shape [0 ]
358- na1 = sino1 .shape [0 ]
359- na2 = sino2 .shape [0 ]
372+ na1 = sino .shape [0 ]
373+ na2 = flip_sino .shape [0 ]
360374
361375 module = load_cuda_module ("center_360_shifts" )
362376 shift_whole_shifts = module .get_function ("shift_whole_shifts" )
363377 # note: we don't have to calculate the mean here, as we're only looking for minimum metric.
364378 # The sum is enough.
365379 masked_sum_abs_kernel = cp .ReductionKernel (
366- in_params = "complex64 x, uint16 mask" , # input, complex + mask
380+ in_params = "complex64 x, float32 mask" , # input, complex + mask
381+ # in_params="complex64 x, uint16 mask", # input, complex + mask
367382 out_params = "float32 out" , # output, real
368- map_expr = "mask ? abs(x) : 0.0f" ,
383+ map_expr = "abs(x) * mask" ,
384+ # map_expr="mask ? abs(x) : 0.0f",
369385 reduce_expr = "a + b" ,
370386 post_map_expr = "out = a" ,
371387 identity = "0.0f" ,
@@ -376,13 +392,14 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
376392 # determine how many shifts we can fit in the available memory
377393 # and iterate in chunks
378394 chunks = _calculate_chunks (
379- nshifts , (na1 + na2 ) * sino2 .shape [1 ] * cp .float32 ().nbytes
395+ nshifts , (na1 + na2 ) * flip_sino .shape [1 ] * cp .float32 ().nbytes
380396 )
381397
382- mat = cp .empty ((chunks [0 ], na1 + na2 , sino2 .shape [1 ]), dtype = cp .float32 )
383- mat [:, :na1 , :] = sino1
398+ mat = cp .empty ((chunks [0 ], na1 + na2 , flip_sino .shape [1 ]), dtype = cp .float32 )
399+ mat [:, :na1 , :] = sino
400+
384401 # explicitly create FFT plan here, so it's not cached and clearly re-used
385- plan = get_fft_plan (mat , mat .shape [- 2 :], axes = (1 , 2 ), value_type = "R2C " )
402+ plan = get_fft_plan (mat , mat .shape [- 2 :], axes = (1 , 2 ), value_type = "C2C " )
386403
387404 for i , stop_idx in enumerate (chunks ):
388405 if i > 0 :
@@ -394,18 +411,18 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
394411 size = stop_idx - start_idx
395412
396413 # first, handle the integer shifts without spline in a raw kernel,
397- # and shift in the sino3 one accordingly
414+ # and shift in the comp_sino one accordingly
398415 bx = 128
399- gx = (sino3 .shape [1 ] + bx - 1 ) // bx
416+ gx = (comp_sino .shape [1 ] + bx - 1 ) // bx
400417 shift_whole_shifts (
401418 grid = (gx , na2 , size ), ####
402419 block = (bx , 1 , 1 ),
403420 args = (
404- sino2 ,
405- sino3 ,
421+ flip_sino ,
422+ comp_sino ,
406423 list_shift [start_idx :stop_idx ],
407424 mat [:, na1 :, :],
408- sino3 .shape [1 ],
425+ comp_sino .shape [1 ],
409426 na1 + na2 ,
410427 ),
411428 )
@@ -415,7 +432,7 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
415432 for i in range (list_shift_host .shape [0 ]):
416433 shift_col = float (list_shift_host [i ])
417434 if not shift_col .is_integer ():
418- shifted = shift (sino2 , (0 , shift_col ), order = 3 , prefilter = True )
435+ shifted = shift (flip_sino , (0 , shift_col ), order = 3 , prefilter = True )
419436 shift_int = round_up (shift_col )
420437 if shift_int >= 0 :
421438 mat [i , na1 :, shift_int :] = shifted [:, shift_int :]
@@ -425,7 +442,9 @@ def _calculate_metric(list_shift, sino1, sino2, sino3, mask, out):
425442 # stack and transform
426443 # (we do the full sized mat FFT, even though the last chunk may be smaller, to
427444 # make sure we can re-use the same FFT plan as before)
428- mat_freq = rfft2 (mat , axes = (1 , 2 ), norm = None , plan = plan )
445+ # mat_freq = fft2(mat, axes=(1, 2), norm=None, plan=plan)
446+ mat_freq = fftshift (fft2 (mat , axes = (1 , 2 ), norm = None , plan = plan ), axes = (1 , 2 ))
447+
429448 masked_sum_abs_kernel (
430449 mat_freq [:size , :, :], mask , out = out [start_idx :stop_idx ], axis = (1 , 2 )
431450 )
0 commit comments