3838 fftshift = Mock ()
3939
4040from numpy import float32
41- from typing import Union
41+ from typing import Tuple
4242import math
4343
4444__all__ = [
@@ -348,25 +348,17 @@ def _shift_bit_length(x: int) -> int:
348348 return 1 << (x - 1 ).bit_length ()
349349
350350
351- def _pad_projections_to_second_power (tomo : cp .ndarray ) -> Union [cp .ndarray , tuple ]:
352- """
353- Performs padding of each projection to the next power of 2.
354- If the shape is not even we also care of that before padding.
351+ def _calculate_pad_size (datashape : tuple ) -> list :
352+ """Calculating the padding size
355353
356- Parameters
357- ----------
358- tomo : cp.ndarray
359- 3d projection data
354+ Args:
355+ datashape (tuple): the shape of the 3D data
360356
361- Returns
362- -------
363- ndarray: padded 3d projection data
364- tuple: a tuple with padding dimensions
357+ Returns:
358+ list: the padded dimensions
365359 """
366- full_shape_tomo = cp .shape (tomo )
367-
368- pad_tup = []
369- for index , element in enumerate (full_shape_tomo ):
360+ pad_list = []
361+ for index , element in enumerate (datashape ):
370362 if index == 0 :
371363 pad_width = (0 , 0 ) # do not pad the slicing dim
372364 else :
@@ -380,11 +372,36 @@ def _pad_projections_to_second_power(tomo: cp.ndarray) -> Union[cp.ndarray, tupl
380372 right_pad = diff - left_pad
381373 pad_width = (left_pad , right_pad )
382374
383- pad_tup .append (pad_width )
375+ pad_list .append (pad_width )
376+
377+ return pad_list
378+
379+
380+ def _pad_projections_to_second_power (
381+ tomo : cp .ndarray ,
382+ ) -> Tuple [cp .ndarray , Tuple [int , int ]]:
383+ """
384+ Performs padding of each projection to the next power of 2.
385+ If the shape is not even we also care of that before padding.
386+
387+ Parameters
388+ ----------
389+ tomo : cp.ndarray
390+ 3d projection data
391+
392+ Returns
393+ -------
394+ Tuple consisting of:
395+ ndarray: padded 3d projection data
396+ tuple: a tuple with padding dimensions
397+ """
398+ full_shape_tomo = cp .shape (tomo )
399+
400+ pad_list = _calculate_pad_size (full_shape_tomo )
384401
385- padded_tomo = cp .pad (tomo , tuple (pad_tup ), "edge" )
402+ padded_tomo = cp .pad (tomo , tuple (pad_list ), "edge" )
386403
387- return padded_tomo , pad_tup
404+ return padded_tomo , tuple ( pad_list )
388405
389406
390407def _paganin_filter_factor2 (energy , dist , alpha , w2 ):
0 commit comments