Skip to content

Commit cc0839a

Browse files
committed
Implement calc_peak_mem for paganin_filter
1 parent 5a99c88 commit cc0839a

File tree

2 files changed

+125
-20
lines changed

2 files changed

+125
-20
lines changed

httomolibgpu/prep/phase.py

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import numpy as np
2424
from httomolibgpu import cupywrapper
25+
from httomolibgpu.memory_estimator_helpers import _DeviceMemStack
2526

2627
cp = cupywrapper.cp
2728
cupy_run = cupywrapper.cupy_run
@@ -30,13 +31,14 @@
3031

3132
if cupy_run:
3233
from cupyx.scipy.fft import fft2, ifft2, fftshift
34+
from cupyx.scipy.fftpack import get_fft_plan
3335
else:
3436
fft2 = Mock()
3537
ifft2 = Mock()
3638
fftshift = Mock()
3739

3840
from numpy import float32
39-
from typing import Tuple
41+
from typing import Optional, Tuple
4042
import math
4143

4244
__all__ = [
@@ -54,6 +56,7 @@ def paganin_filter(
5456
distance: float = 1.0,
5557
energy: float = 53.0,
5658
ratio_delta_beta: float = 250,
59+
calc_peak_gpu_mem: bool = False,
5760
) -> cp.ndarray:
5861
"""
5962
Perform single-material phase retrieval from flats/darks corrected tomographic measurements. For more detailed information, see :ref:`phase_contrast_module`.
@@ -77,24 +80,42 @@ def paganin_filter(
7780
cp.ndarray
7881
The 3D array of Paganin phase-filtered projection images.
7982
"""
83+
mem_stack = _DeviceMemStack() if calc_peak_gpu_mem else None
8084
# Check the input data is valid
81-
if tomo.ndim != 3:
85+
if not mem_stack and tomo.ndim != 3:
8286
raise ValueError(
8387
f"Invalid number of dimensions in data: {tomo.ndim},"
8488
" please provide a stack of 2D projections."
8589
)
86-
87-
dz_orig, dy_orig, dx_orig = tomo.shape
90+
if mem_stack:
91+
mem_stack.malloc(np.prod(tomo) * np.float32().itemsize)
92+
dz_orig, dy_orig, dx_orig = tomo.shape if not mem_stack else tomo
8893

8994
# Perform padding to the power of 2 as FFT is O(n*log(n)) complexity
9095
# TODO: adding other options of padding?
91-
padded_tomo, pad_tup = _pad_projections_to_second_power(tomo)
96+
padded_tomo, pad_tup = _pad_projections_to_second_power(tomo, mem_stack)
9297

93-
dz, dy, dx = padded_tomo.shape
98+
dz, dy, dx = padded_tomo.shape if not mem_stack else padded_tomo
9499

95100
# 3D FFT of tomo data
96-
padded_tomo = cp.asarray(padded_tomo, dtype=cp.complex64)
97-
fft_tomo = fft2(padded_tomo, axes=(-2, -1), overwrite_x=True)
101+
if mem_stack:
102+
mem_stack.malloc(np.prod(padded_tomo) * np.complex64().itemsize)
103+
mem_stack.free(np.prod(padded_tomo) * np.float32().itemsize)
104+
fft_input = cp.empty(padded_tomo, dtype=cp.complex64)
105+
else:
106+
padded_tomo = cp.asarray(padded_tomo, dtype=cp.complex64)
107+
fft_input = padded_tomo
108+
109+
fft_plan = get_fft_plan(fft_input, axes=(-2, -1))
110+
if mem_stack:
111+
mem_stack.malloc(fft_plan.work_area.mem.size)
112+
mem_stack.free(fft_plan.work_area.mem.size)
113+
else:
114+
with fft_plan:
115+
fft_tomo = fft2(padded_tomo, axes=(-2, -1), overwrite_x=True)
116+
del padded_tomo
117+
del fft_input
118+
del fft_plan
98119

99120
# calculate alpha constant
100121
alpha = _calculate_alpha(energy, distance / 1e-6, ratio_delta_beta)
@@ -103,18 +124,41 @@ def paganin_filter(
103124
indx = _reciprocal_coord(pixel_size, dy)
104125
indy = _reciprocal_coord(pixel_size, dx)
105126

106-
# Build Lorentzian-type filter
107-
phase_filter = fftshift(
108-
1.0 / (1.0 + alpha * (cp.add.outer(cp.square(indx), cp.square(indy))))
109-
)
127+
if mem_stack:
128+
mem_stack.malloc(indx.size * indx.dtype.itemsize) # cp.asarray(indx)
129+
mem_stack.malloc(indx.size ** 2 * indx.dtype.itemsize) # cp.square
130+
mem_stack.malloc(indy.size * indy.dtype.itemsize) # cp.asarray(indy)
131+
mem_stack.malloc(indy.size ** 2 * indy.dtype.itemsize) # cp.square
132+
133+
mem_stack.malloc(indx.size ** 2 * indx.dtype.itemsize) # phase_filter
134+
135+
mem_stack.free(indx.size * indx.dtype.itemsize)
136+
mem_stack.free(indx.size ** 2 * indx.dtype.itemsize)
137+
mem_stack.free(indy.size * indy.dtype.itemsize)
138+
mem_stack.free(indy.size ** 2 * indy.dtype.itemsize)
139+
else:
140+
# Build Lorentzian-type filter
141+
phase_filter = fftshift(
142+
1.0 / (1.0 + alpha * (cp.add.outer(cp.square(cp.asarray(indx)), cp.square(cp.asarray(indy)))))
143+
)
110144

111-
phase_filter = phase_filter / phase_filter.max() # normalisation
145+
phase_filter = phase_filter / phase_filter.max() # normalisation
112146

113-
# Filter projections
114-
fft_tomo *= phase_filter
147+
# Filter projections
148+
fft_tomo *= phase_filter
115149

116150
# Apply filter and take inverse FFT
117-
ifft_filtered_tomo = ifft2(fft_tomo, axes=(-2, -1), overwrite_x=True).real
151+
ifft_input = fft_tomo if not mem_stack else cp.empty(padded_tomo, dtype=cp.complex64)
152+
ifft_plan = get_fft_plan(ifft_input, axes=(-2, -1))
153+
if mem_stack:
154+
mem_stack.malloc(ifft_plan.work_area.mem.size)
155+
mem_stack.free(ifft_plan.work_area.mem.size)
156+
else:
157+
with ifft_plan:
158+
ifft_filtered_tomo = ifft2(fft_tomo, axes=(-2, -1), overwrite_x=True).real
159+
del fft_tomo
160+
del ifft_plan
161+
del ifft_input
118162

119163
# slicing indices for cropping
120164
slc_indices = (
@@ -123,8 +167,15 @@ def paganin_filter(
123167
slice(pad_tup[2][0], pad_tup[2][0] + dx_orig, 1),
124168
)
125169

170+
if mem_stack:
171+
mem_stack.malloc(np.prod(tomo) * np.float32().itemsize) # astype(cp.float32)
172+
mem_stack.free(np.prod(padded_tomo) * np.complex64().itemsize) # ifft_filtered_tomo
173+
mem_stack.malloc(np.prod(tomo) * np.float32().itemsize) # return _log_kernel(tomo)
174+
return mem_stack.highwater
175+
126176
# crop the padded filtered data:
127177
tomo = ifft_filtered_tomo[slc_indices].astype(cp.float32)
178+
del ifft_filtered_tomo
128179

129180
# taking the negative log
130181
_log_kernel = cp.ElementwiseKernel(
@@ -178,6 +229,7 @@ def _calculate_pad_size(datashape: tuple) -> list:
178229

179230
def _pad_projections_to_second_power(
180231
tomo: cp.ndarray,
232+
mem_stack: Optional[_DeviceMemStack]
181233
) -> Tuple[cp.ndarray, Tuple[int, int]]:
182234
"""
183235
Performs padding of each projection to the next power of 2.
@@ -194,11 +246,15 @@ def _pad_projections_to_second_power(
194246
ndarray: padded 3d projection data
195247
tuple: a tuple with padding dimensions
196248
"""
197-
full_shape_tomo = cp.shape(tomo)
249+
full_shape_tomo = cp.shape(tomo) if not mem_stack else tomo
198250

199251
pad_list = _calculate_pad_size(full_shape_tomo)
200252

201-
padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")
253+
if mem_stack:
254+
padded_tomo = [sh + pad[0] + pad[1] for sh, pad in zip(full_shape_tomo, pad_list)]
255+
mem_stack.malloc(np.prod(padded_tomo) * np.float32().itemsize)
256+
else:
257+
padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")
202258

203259
return padded_tomo, tuple(pad_list)
204260

@@ -209,7 +265,7 @@ def _wavelength_micron(energy: float) -> float:
209265
return 2 * math.pi * PLANCK_CONSTANT * SPEED_OF_LIGHT / energy
210266

211267

212-
def _reciprocal_coord(pixel_size: float, num_grid: int) -> cp.ndarray:
268+
def _reciprocal_coord(pixel_size: float, num_grid: int) -> np.ndarray:
213269
"""
214270
Calculate reciprocal grid coordinates for a given pixel size
215271
and discretization.
@@ -227,7 +283,7 @@ def _reciprocal_coord(pixel_size: float, num_grid: int) -> cp.ndarray:
227283
Grid coordinates.
228284
"""
229285
n = num_grid - 1
230-
rc = cp.arange(-n, num_grid, 2, dtype=cp.float32)
286+
rc = np.arange(-n, num_grid, 2, dtype=cp.float32)
231287
rc *= 2 * math.pi / (n * pixel_size)
232288
return rc
233289

tests/test_prep/test_phase.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from httomolibgpu.prep.phase import paganin_filter
88
from numpy.testing import assert_allclose
99

10+
from ..conftest import MaxMemoryHook
11+
1012
eps = 1e-6
1113

1214

@@ -81,3 +83,50 @@ def test_paganin_filter_performance(ensure_clean_memory):
8183
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10
8284

8385
assert "performance in ms" == duration_ms
86+
87+
@pytest.mark.parametrize("slices", [3, 7, 32, 61, 109, 120, 150])
88+
@pytest.mark.parametrize("dim_x", [128, 140])
89+
def test_paganin_filter_calc_mem(slices, dim_x, ensure_clean_memory):
90+
dim_y = 159
91+
data = cp.random.random_sample((slices, dim_x, dim_y), dtype=np.float32)
92+
hook = MaxMemoryHook()
93+
with hook:
94+
paganin_filter(cp.copy(data))
95+
actual_mem_peak = hook.max_mem
96+
97+
try:
98+
estimated_mem_peak = paganin_filter(
99+
data.shape, calc_peak_gpu_mem=True
100+
)
101+
except cp.cuda.memory.OutOfMemoryError:
102+
pytest.skip("Not enough GPU memory to estimate memory peak")
103+
104+
assert actual_mem_peak * 0.99 <= estimated_mem_peak
105+
assert estimated_mem_peak <= actual_mem_peak * 1.01
106+
107+
108+
@pytest.mark.parametrize(
109+
"slices", [38, 177, 268, 320, 490, 607, 803, 859, 902, 951]
110+
)
111+
@pytest.mark.parametrize("dims", [(900, 1280), (1801, 1540), (1801, 2560)])
112+
def test_paganin_filter_calc_mem_big(slices, dims, ensure_clean_memory):
113+
dim_y, dim_x = dims
114+
data_shape = (slices, dim_x, dim_y)
115+
try:
116+
estimated_mem_peak = paganin_filter(
117+
data_shape, calc_peak_gpu_mem=True
118+
)
119+
except cp.cuda.memory.OutOfMemoryError:
120+
pytest.skip("Not enough GPU memory to estimate memory peak")
121+
av_mem = cp.cuda.Device().mem_info[0]
122+
if av_mem < estimated_mem_peak:
123+
pytest.skip("Not enough GPU memory to run this test")
124+
125+
hook = MaxMemoryHook()
126+
with hook:
127+
data = cp.random.random_sample(data_shape, dtype=np.float32)
128+
paganin_filter(data)
129+
actual_mem_peak = hook.max_mem
130+
131+
assert actual_mem_peak * 0.99 <= estimated_mem_peak
132+
assert estimated_mem_peak <= actual_mem_peak * 1.01

0 commit comments

Comments
 (0)