
Commit a841b54

adding docstring and linting
1 parent 9737251 commit a841b54

3 files changed: 60 additions, 54 deletions

httomolibgpu/prep/phase.py

Lines changed: 47 additions & 20 deletions
@@ -74,6 +74,8 @@ def paganin_filter(
         Beam energy in keV.
     ratio_delta_beta : float
         The ratio of delta/beta, where delta is the phase shift and real part of the complex material refractive index and beta is the absorption.
+    calc_peak_gpu_mem: bool
+        Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.

     Returns
     -------
@@ -125,24 +127,33 @@ def paganin_filter(
     indy = _reciprocal_coord(pixel_size, dx)

     if mem_stack:
-        mem_stack.malloc(indx.size * indx.dtype.itemsize) # cp.asarray(indx)
-        mem_stack.malloc(indx.size * indx.dtype.itemsize) # cp.square
-        mem_stack.free(indx.size * indx.dtype.itemsize) # cp.asarray(indx)
+        mem_stack.malloc(indx.size * indx.dtype.itemsize)  # cp.asarray(indx)
+        mem_stack.malloc(indx.size * indx.dtype.itemsize)  # cp.square
+        mem_stack.free(indx.size * indx.dtype.itemsize)  # cp.asarray(indx)
         mem_stack.malloc(indy.size * indy.dtype.itemsize) # cp.asarray(indy)
-        mem_stack.malloc(indy.size * indy.dtype.itemsize) # cp.square
+        mem_stack.malloc(indy.size * indy.dtype.itemsize)  # cp.square
         mem_stack.free(indy.size * indy.dtype.itemsize) # cp.asarray(indy)

-        mem_stack.malloc(indx.size * indy.size * indx.dtype.itemsize) # cp.add.outer
-        mem_stack.free(indx.size * indx.dtype.itemsize) # cp.square
-        mem_stack.free(indy.size * indy.dtype.itemsize) # cp.square
-        mem_stack.malloc(indx.size * indy.size * indx.dtype.itemsize) # phase_filter
-        mem_stack.free(indx.size * indy.size * indx.dtype.itemsize) # cp.add.outer
-        mem_stack.free(indx.size * indy.size * indx.dtype.itemsize) # phase_filter
+        mem_stack.malloc(indx.size * indy.size * indx.dtype.itemsize)  # cp.add.outer
+        mem_stack.free(indx.size * indx.dtype.itemsize)  # cp.square
+        mem_stack.free(indy.size * indy.dtype.itemsize)  # cp.square
+        mem_stack.malloc(indx.size * indy.size * indx.dtype.itemsize)  # phase_filter
+        mem_stack.free(indx.size * indy.size * indx.dtype.itemsize)  # cp.add.outer
+        mem_stack.free(indx.size * indy.size * indx.dtype.itemsize)  # phase_filter

     else:
         # Build Lorentzian-type filter
         phase_filter = fftshift(
-            1.0 / (1.0 + alpha * (cp.add.outer(cp.square(cp.asarray(indx)), cp.square(cp.asarray(indy)))))
+            1.0
+            / (
+                1.0
+                + alpha
+                * (
+                    cp.add.outer(
+                        cp.square(cp.asarray(indx)), cp.square(cp.asarray(indy))
+                    )
+                )
+            )
         )

         phase_filter = phase_filter / phase_filter.max() # normalisation
@@ -152,15 +163,17 @@ def paganin_filter(
     del phase_filter

     # Apply filter and take inverse FFT
-    ifft_input = fft_tomo if not mem_stack else cp.empty(padded_tomo, dtype=cp.complex64)
+    ifft_input = (
+        fft_tomo if not mem_stack else cp.empty(padded_tomo, dtype=cp.complex64)
+    )
     ifft_plan = get_fft_plan(ifft_input, axes=(-2, -1))
     if mem_stack:
         mem_stack.malloc(ifft_plan.work_area.mem.size)
         mem_stack.free(ifft_plan.work_area.mem.size)
     else:
         with ifft_plan:
             ifft_filtered_tomo = ifft2(fft_tomo, axes=(-2, -1), overwrite_x=True).real
-            del fft_tomo
+            del fft_tomo
     del ifft_plan
     del ifft_input

@@ -172,9 +185,13 @@ def paganin_filter(
     )

     if mem_stack:
-        mem_stack.malloc(np.prod(tomo) * np.float32().itemsize) # astype(cp.float32)
-        mem_stack.free(np.prod(padded_tomo) * np.complex64().itemsize) # ifft_filtered_tomo
-        mem_stack.malloc(np.prod(tomo) * np.float32().itemsize) # return _log_kernel(tomo)
+        mem_stack.malloc(np.prod(tomo) * np.float32().itemsize)  # astype(cp.float32)
+        mem_stack.free(
+            np.prod(padded_tomo) * np.complex64().itemsize
+        )  # ifft_filtered_tomo
+        mem_stack.malloc(
+            np.prod(tomo) * np.float32().itemsize
+        )  # return _log_kernel(tomo)
         return mem_stack.highwater

     # crop the padded filtered data:
@@ -232,8 +249,7 @@ def _calculate_pad_size(datashape: tuple) -> list:


 def _pad_projections_to_second_power(
-    tomo: cp.ndarray,
-    mem_stack: Optional[_DeviceMemStack]
+    tomo: cp.ndarray, mem_stack: Optional[_DeviceMemStack]
 ) -> Tuple[cp.ndarray, Tuple[int, int]]:
     """
     Performs padding of each projection to the next power of 2.
@@ -255,7 +271,9 @@ def _pad_projections_to_second_power(
     pad_list = _calculate_pad_size(full_shape_tomo)

     if mem_stack:
-        padded_tomo = [sh + pad[0] + pad[1] for sh, pad in zip(full_shape_tomo, pad_list)]
+        padded_tomo = [
+            sh + pad[0] + pad[1] for sh, pad in zip(full_shape_tomo, pad_list)
+        ]
         mem_stack.malloc(np.prod(padded_tomo) * np.float32().itemsize)
     else:
         padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")
@@ -317,11 +335,20 @@ def paganin_filter_savu_legacy(
         Beam energy in keV.
     ratio_delta_beta : float
         The ratio of delta/beta, where delta is the phase shift and real part of the complex material refractive index and beta is the absorption.
+    calc_peak_gpu_mem: bool
+        Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.

     Returns
     -------
     cp.ndarray
         The 3D array of Paganin phase-filtered projection images.
     """

-    return paganin_filter(tomo, pixel_size, distance, energy, ratio_delta_beta / 4, calc_peak_gpu_mem=calc_peak_gpu_mem)
+    return paganin_filter(
+        tomo,
+        pixel_size,
+        distance,
+        energy,
+        ratio_delta_beta / 4,
+        calc_peak_gpu_mem=calc_peak_gpu_mem,
+    )
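
For context on the calc_peak_gpu_mem flag documented above, here is a minimal usage sketch (not part of this commit; the shape values are illustrative). It follows the call pattern in tests/test_prep/test_phase.py further down: in estimation mode the function receives the data shape instead of a CuPy array and returns the estimated peak GPU memory in bytes, i.e. the high-water mark accumulated on the internal _DeviceMemStack.

from httomolibgpu.prep.phase import paganin_filter

# Estimate the peak GPU memory of the Paganin filter for a given data shape
# without allocating the data itself; shape order (slices, dim_x, dim_y) as in the tests.
data_shape = (120, 128, 160)
estimated_peak_bytes = paganin_filter(data_shape, calc_peak_gpu_mem=True)
print(f"estimated GPU memory peak: {estimated_peak_bytes / 2**20:.1f} MiB")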

httomolibgpu/recon/_phase_cross_correlation.py

Lines changed: 9 additions & 25 deletions
@@ -36,9 +36,8 @@
 import cupyx.scipy.ndimage as ndi
 import numpy as np

-def _upsampled_dft(
-    data, upsampled_region_size, upsample_factor=1, axis_offsets=None
-):
+
+def _upsampled_dft(data, upsampled_region_size, upsample_factor=1, axis_offsets=None):
     """
     Upsampled DFT by matrix multiplication.

@@ -148,9 +147,7 @@ def _compute_error(cross_correlation_max, src_amp, target_amp):
     )

     with np.errstate(invalid="ignore"):
-        error = 1.0 - cross_correlation_max * cross_correlation_max.conj() / (
-            amp
-        )
+        error = 1.0 - cross_correlation_max * cross_correlation_max.conj() / (amp)

     return cp.sqrt(cp.abs(error))

@@ -192,9 +189,7 @@ def _disambiguate_shift(reference_image, moving_image, shift):
     negative_shift = [shift_i - s for shift_i, s in zip(positive_shift, shape)]
     subpixel = any(s % 1 != 0 for s in shift)
     interp_order = 3 if subpixel else 0
-    shifted = ndi.shift(
-        moving_image, shift, mode="grid-wrap", order=interp_order
-    )
+    shifted = ndi.shift(moving_image, shift, mode="grid-wrap", order=interp_order)
     indices = tuple(round(s) for s in positive_shift)
     splits_per_dim = [(slice(0, i), slice(i, None)) for i in indices]
     max_corr = -1.0
@@ -217,9 +212,7 @@ def _disambiguate_shift(reference_image, moving_image, shift):
         )
         return shift
     real_shift_acc = []
-    for sl, pos_shift, neg_shift in zip(
-        max_slice, positive_shift, negative_shift
-    ):
+    for sl, pos_shift, neg_shift in zip(max_slice, positive_shift, negative_shift):
         real_shift_acc.append(pos_shift if sl.stop is None else neg_shift)
     if not subpixel:
         real_shift = tuple(map(int, real_shift_acc))
@@ -359,16 +352,12 @@ def phase_cross_correlation(
         # Initial shift estimate in upsampled grid
         # shift = cp.around(shift * upsample_factor) / upsample_factor
         upsample_factor = float(upsample_factor)
-        shift = tuple(
-            round(s * upsample_factor) / upsample_factor for s in shift
-        )
+        shift = tuple(round(s * upsample_factor) / upsample_factor for s in shift)
         upsampled_region_size = math.ceil(upsample_factor * 1.5)
         # Center of output array at dftshift + 1
         dftshift = float(upsampled_region_size // 2)
         # Matrix multiply DFT around the current shift estimate
-        sample_region_offset = tuple(
-            dftshift - s * upsample_factor for s in shift
-        )
+        sample_region_offset = tuple(dftshift - s * upsample_factor for s in shift)
         cross_correlation = _upsampled_dft(
             image_product.conj(),
             upsampled_region_size,
@@ -394,9 +383,7 @@ def phase_cross_correlation(

     # If its only one row or column the shift along that dimension has no
     # effect. We set to zero.
-    shift = tuple(
-        s if axis_size != 1 else 0 for s, axis_size in zip(shift, shape)
-    )
+    shift = tuple(s if axis_size != 1 else 0 for s, axis_size in zip(shift, shape))

     if disambiguate:
         if space.lower() != "real":
@@ -406,10 +393,7 @@ def phase_cross_correlation(

     # Redirect user to masked_phase_cross_correlation if NaNs are observed
     if cp.isnan(CCmax) or cp.isnan(src_amp) or cp.isnan(target_amp):
-        raise ValueError(
-            "NaN values found, please remove NaNs from your "
-            "input data"
-        )
+        raise ValueError("NaN values found, please remove NaNs from your " "input data")

     return (
         shift,
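
The _upsampled_dft docstring touched above names the core trick, "Upsampled DFT by matrix multiplication." As background only (not part of this commit, and not the library's code), a minimal 1D NumPy sketch of the idea: evaluating the inverse DFT on a grid refined by upsample_factor via a single matrix product agrees with zero-padding the spectrum and running a much larger inverse FFT, which is why the registration can refine the cross-correlation peak locally without a full upsampled FFT.

import numpy as np

rng = np.random.default_rng(0)
n = 64
upsample_factor = 8

signal = rng.standard_normal(n)
F = np.fft.fft(signal)

# Upsampled inverse DFT via one matrix multiplication: evaluate the
# band-limited interpolant of `signal` on a grid refined by `upsample_factor`.
region = n * upsample_factor
kernel = np.exp(
    2j * np.pi * np.outer(np.arange(region) / upsample_factor, np.fft.fftfreq(n))
)
upsampled = kernel @ F / n

# Reference result: zero-pad the spectrum and take a large inverse FFT.
padded = np.zeros(region, dtype=complex)
padded[: n // 2] = F[: n // 2]
padded[-(n // 2):] = F[-(n // 2):]
reference = np.fft.ifft(padded) * upsample_factor

np.testing.assert_allclose(upsampled, reference, atol=1e-9)
print("matrix-multiplication DFT matches zero-padded FFT upsampling")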

tests/test_prep/test_phase.py

Lines changed: 4 additions & 9 deletions
@@ -84,6 +84,7 @@ def test_paganin_filter_performance(ensure_clean_memory):

     assert "performance in ms" == duration_ms

+
 @pytest.mark.parametrize("slices", [3, 7, 32, 61, 109, 120, 150])
 @pytest.mark.parametrize("dim_x", [128, 140])
 def test_paganin_filter_calc_mem(slices, dim_x, ensure_clean_memory):
@@ -95,27 +96,21 @@ def test_paganin_filter_calc_mem(slices, dim_x, ensure_clean_memory):
     actual_mem_peak = hook.max_mem

     try:
-        estimated_mem_peak = paganin_filter(
-            data.shape, calc_peak_gpu_mem=True
-        )
+        estimated_mem_peak = paganin_filter(data.shape, calc_peak_gpu_mem=True)
     except cp.cuda.memory.OutOfMemoryError:
         pytest.skip("Not enough GPU memory to estimate memory peak")

     assert actual_mem_peak * 0.99 <= estimated_mem_peak
     assert estimated_mem_peak <= actual_mem_peak * 1.01


-@pytest.mark.parametrize(
-    "slices", [38, 177, 268, 320, 490, 607, 803, 859, 902, 951]
-)
+@pytest.mark.parametrize("slices", [38, 177, 268, 320, 490, 607, 803, 859, 902, 951])
 @pytest.mark.parametrize("dims", [(900, 1280), (1801, 1540), (1801, 2560)])
 def test_paganin_filter_calc_mem_big(slices, dims, ensure_clean_memory):
     dim_y, dim_x = dims
     data_shape = (slices, dim_x, dim_y)
     try:
-        estimated_mem_peak = paganin_filter(
-            data_shape, calc_peak_gpu_mem=True
-        )
+        estimated_mem_peak = paganin_filter(data_shape, calc_peak_gpu_mem=True)
     except cp.cuda.memory.OutOfMemoryError:
         pytest.skip("Not enough GPU memory to estimate memory peak")
     av_mem = cp.cuda.Device().mem_info[0]
