Skip to content

Commit c0075e9

Browse files
committed
moving the padding estimator into a separate function
1 parent 64d8e31 commit c0075e9

File tree

1 file changed

+32
-18
lines changed

1 file changed

+32
-18
lines changed

httomolibgpu/prep/phase.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,35 @@ def _shift_bit_length(x: int) -> int:
348348
return 1 << (x - 1).bit_length()
349349

350350

351+
def _calculate_pad_size(datashape: tuple) -> list:
352+
"""Calculating the padding size
353+
354+
Args:
355+
datashape (tuple): the shape of the 3D data
356+
357+
Returns:
358+
list: the padded dimensions
359+
"""
360+
pad_list = []
361+
for index, element in enumerate(datashape):
362+
if index == 0:
363+
pad_width = (0, 0) # do not pad the slicing dim
364+
else:
365+
diff = _shift_bit_length(element + 1) - element
366+
if element % 2 == 0:
367+
pad_width_scalar = diff // 2
368+
pad_width = (pad_width_scalar, pad_width_scalar)
369+
else:
370+
# need an uneven padding for odd-number lengths
371+
left_pad = diff // 2
372+
right_pad = diff - left_pad
373+
pad_width = (left_pad, right_pad)
374+
375+
pad_list.append(pad_width)
376+
377+
return pad_list
378+
379+
351380
def _pad_projections_to_second_power(tomo: cp.ndarray) -> Union[cp.ndarray, tuple]:
352381
"""
353382
Performs padding of each projection to the next power of 2.
@@ -365,26 +394,11 @@ def _pad_projections_to_second_power(tomo: cp.ndarray) -> Union[cp.ndarray, tupl
365394
"""
366395
full_shape_tomo = cp.shape(tomo)
367396

368-
pad_tup = []
369-
for index, element in enumerate(full_shape_tomo):
370-
if index == 0:
371-
pad_width = (0, 0) # do not pad the slicing dim
372-
else:
373-
diff = _shift_bit_length(element + 1) - element
374-
if element % 2 == 0:
375-
pad_width_scalar = diff // 2
376-
pad_width = (pad_width_scalar, pad_width_scalar)
377-
else:
378-
# need an uneven padding for odd-number lengths
379-
left_pad = diff // 2
380-
right_pad = diff - left_pad
381-
pad_width = (left_pad, right_pad)
382-
383-
pad_tup.append(pad_width)
397+
pad_list = _calculate_pad_size(full_shape_tomo)
384398

385-
padded_tomo = cp.pad(tomo, tuple(pad_tup), "edge")
399+
padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")
386400

387-
return padded_tomo, pad_tup
401+
return padded_tomo, tuple(pad_list)
388402

389403

390404
def _paganin_filter_factor2(energy, dist, alpha, w2):

0 commit comments

Comments
 (0)