Skip to content

Commit cfe029c

Browse files
authored
Merge pull request #164 from DiamondLightSource/phase_filter
moving the padding estimator into a separate function for Paganin
2 parents 192da31 + b5abb5e commit cfe029c

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

httomolibgpu/prep/phase.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
fftshift = Mock()
3939

4040
from numpy import float32
41-
from typing import Union
41+
from typing import Tuple
4242
import math
4343

4444
__all__ = [
@@ -348,25 +348,17 @@ def _shift_bit_length(x: int) -> int:
348348
return 1 << (x - 1).bit_length()
349349

350350

351-
def _pad_projections_to_second_power(tomo: cp.ndarray) -> Union[cp.ndarray, tuple]:
352-
"""
353-
Performs padding of each projection to the next power of 2.
354-
If the shape is not even we also care of that before padding.
351+
def _calculate_pad_size(datashape: tuple) -> list:
352+
"""Calculating the padding size
355353
356-
Parameters
357-
----------
358-
tomo : cp.ndarray
359-
3d projection data
354+
Args:
355+
datashape (tuple): the shape of the 3D data
360356
361-
Returns
362-
-------
363-
ndarray: padded 3d projection data
364-
tuple: a tuple with padding dimensions
357+
Returns:
358+
list: the padded dimensions
365359
"""
366-
full_shape_tomo = cp.shape(tomo)
367-
368-
pad_tup = []
369-
for index, element in enumerate(full_shape_tomo):
360+
pad_list = []
361+
for index, element in enumerate(datashape):
370362
if index == 0:
371363
pad_width = (0, 0) # do not pad the slicing dim
372364
else:
@@ -380,11 +372,36 @@ def _pad_projections_to_second_power(tomo: cp.ndarray) -> Union[cp.ndarray, tupl
380372
right_pad = diff - left_pad
381373
pad_width = (left_pad, right_pad)
382374

383-
pad_tup.append(pad_width)
375+
pad_list.append(pad_width)
376+
377+
return pad_list
378+
379+
380+
def _pad_projections_to_second_power(
381+
tomo: cp.ndarray,
382+
) -> Tuple[cp.ndarray, Tuple[int, int]]:
383+
"""
384+
Performs padding of each projection to the next power of 2.
385+
If the shape is not even we also care of that before padding.
386+
387+
Parameters
388+
----------
389+
tomo : cp.ndarray
390+
3d projection data
391+
392+
Returns
393+
-------
394+
Tuple consisting of:
395+
ndarray: padded 3d projection data
396+
tuple: a tuple with padding dimensions
397+
"""
398+
full_shape_tomo = cp.shape(tomo)
399+
400+
pad_list = _calculate_pad_size(full_shape_tomo)
384401

385-
padded_tomo = cp.pad(tomo, tuple(pad_tup), "edge")
402+
padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")
386403

387-
return padded_tomo, pad_tup
404+
return padded_tomo, tuple(pad_list)
388405

389406

390407
def _paganin_filter_factor2(energy, dist, alpha, w2):

0 commit comments

Comments
 (0)