Skip to content

Commit e807952

Browse files
authored
Merge pull request #244 from DiamondLightSource/paganin-iterative-mem-est-radway-71
Inline memory peak calculator for Paganin filter
2 parents 8bcc5fc + a841b54 commit e807952

File tree

6 files changed

+208
-47
lines changed

6 files changed

+208
-47
lines changed
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Granularity, in bytes, to which every simulated allocation is rounded up.
# NOTE(review): presumably mirrors the rounding of the device memory pool -
# confirm against the allocator actually in use.
ALLOCATION_UNIT_SIZE = 512


class _DeviceMemStack:
    """
    Bookkeeping helper that simulates device allocations to find a peak.

    Nothing is allocated here: callers report requested sizes through
    ``malloc`` and ``free``. Each size is rounded up to a multiple of
    ``ALLOCATION_UNIT_SIZE`` before it is added to or removed from the
    running total, and the largest running total seen is kept in
    ``highwater``.
    """

    def __init__(self) -> None:
        # Requested (un-rounded) sizes of the allocations still outstanding.
        self.allocations = []
        # Sum of the rounded-up sizes currently outstanding.
        self.current = 0
        # Largest value ``current`` has reached so far.
        self.highwater = 0

    def malloc(self, bytes):
        """Record an allocation of ``bytes`` bytes and update the peak."""
        self.allocations.append(bytes)
        self.current += self._round_up(bytes)
        if self.current >= self.highwater:
            self.highwater = self.current

    def free(self, bytes):
        """
        Release one outstanding allocation that was recorded with ``bytes``.

        The size must match a value previously passed to ``malloc``;
        allocations are identified by their requested size only.
        """
        assert bytes in self.allocations
        self.allocations.remove(bytes)
        self.current -= self._round_up(bytes)
        assert self.current >= 0

    def _round_up(self, size):
        """Return ``size`` rounded up to a multiple of ALLOCATION_UNIT_SIZE."""
        bumped = size + ALLOCATION_UNIT_SIZE - 1
        # Dropping the remainder of the bumped value yields the next multiple.
        return bumped - bumped % ALLOCATION_UNIT_SIZE

httomolibgpu/prep/phase.py

Lines changed: 110 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
import numpy as np
2424
from httomolibgpu import cupywrapper
25+
from httomolibgpu.memory_estimator_helpers import _DeviceMemStack
2526

2627
cp = cupywrapper.cp
2728
cupy_run = cupywrapper.cupy_run
@@ -30,13 +31,14 @@
3031

3132
if cupy_run:
    from cupyx.scipy.fft import fft2, ifft2, fftshift
    from cupyx.scipy.fftpack import get_fft_plan
else:
    # CuPy is not usable: bind a placeholder for every name imported in the
    # branch above, so the module-level namespace is the same either way.
    fft2 = Mock()
    ifft2 = Mock()
    fftshift = Mock()
    # Previously missing: without it `get_fft_plan` was undefined here.
    get_fft_plan = Mock()
3739

3840
from numpy import float32
39-
from typing import Tuple
41+
from typing import Optional, Tuple
4042
import math
4143

4244
__all__ = [
@@ -54,6 +56,7 @@ def paganin_filter(
5456
distance: float = 1.0,
5557
energy: float = 53.0,
5658
ratio_delta_beta: float = 250,
59+
calc_peak_gpu_mem: bool = False,
5760
) -> cp.ndarray:
5861
"""
5962
Perform single-material phase retrieval from flats/darks corrected tomographic measurements. For more detailed information, see :ref:`phase_contrast_module`.
@@ -71,30 +74,50 @@ def paganin_filter(
7174
Beam energy in keV.
7275
ratio_delta_beta : float
7376
The ratio of delta/beta, where delta is the phase shift and real part of the complex material refractive index and beta is the absorption.
77+
calc_peak_gpu_mem: bool
78+
Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.
7479
7580
Returns
7681
-------
7782
cp.ndarray
7883
The 3D array of Paganin phase-filtered projection images.
7984
"""
85+
mem_stack = _DeviceMemStack() if calc_peak_gpu_mem else None
8086
# Check the input data is valid
81-
if tomo.ndim != 3:
87+
if not mem_stack and tomo.ndim != 3:
8288
raise ValueError(
8389
f"Invalid number of dimensions in data: {tomo.ndim},"
8490
" please provide a stack of 2D projections."
8591
)
86-
87-
dz_orig, dy_orig, dx_orig = tomo.shape
92+
if mem_stack:
93+
mem_stack.malloc(np.prod(tomo) * np.float32().itemsize)
94+
dz_orig, dy_orig, dx_orig = tomo.shape if not mem_stack else tomo
8895

8996
# Perform padding to the power of 2 as FFT is O(n*log(n)) complexity
9097
# TODO: adding other options of padding?
91-
padded_tomo, pad_tup = _pad_projections_to_second_power(tomo)
98+
padded_tomo, pad_tup = _pad_projections_to_second_power(tomo, mem_stack)
9299

93-
dz, dy, dx = padded_tomo.shape
100+
dz, dy, dx = padded_tomo.shape if not mem_stack else padded_tomo
94101

95102
# 3D FFT of tomo data
96-
padded_tomo = cp.asarray(padded_tomo, dtype=cp.complex64)
97-
fft_tomo = fft2(padded_tomo, axes=(-2, -1), overwrite_x=True)
103+
if mem_stack:
104+
mem_stack.malloc(np.prod(padded_tomo) * np.complex64().itemsize)
105+
mem_stack.free(np.prod(padded_tomo) * np.float32().itemsize)
106+
fft_input = cp.empty(padded_tomo, dtype=cp.complex64)
107+
else:
108+
padded_tomo = cp.asarray(padded_tomo, dtype=cp.complex64)
109+
fft_input = padded_tomo
110+
111+
fft_plan = get_fft_plan(fft_input, axes=(-2, -1))
112+
if mem_stack:
113+
mem_stack.malloc(fft_plan.work_area.mem.size)
114+
mem_stack.free(fft_plan.work_area.mem.size)
115+
else:
116+
with fft_plan:
117+
fft_tomo = fft2(padded_tomo, axes=(-2, -1), overwrite_x=True)
118+
del padded_tomo
119+
del fft_input
120+
del fft_plan
98121

99122
# calculate alpha constant
100123
alpha = _calculate_alpha(energy, distance / 1e-6, ratio_delta_beta)
@@ -103,18 +126,56 @@ def paganin_filter(
103126
indx = _reciprocal_coord(pixel_size, dy)
104127
indy = _reciprocal_coord(pixel_size, dx)
105128

106-
# Build Lorentzian-type filter
107-
phase_filter = fftshift(
108-
1.0 / (1.0 + alpha * (cp.add.outer(cp.square(indx), cp.square(indy))))
109-
)
129+
if mem_stack:
130+
mem_stack.malloc(indx.size * indx.dtype.itemsize) # cp.asarray(indx)
131+
mem_stack.malloc(indx.size * indx.dtype.itemsize) # cp.square
132+
mem_stack.free(indx.size * indx.dtype.itemsize) # cp.asarray(indx)
133+
mem_stack.malloc(indy.size * indy.dtype.itemsize) # cp.asarray(indy)
134+
mem_stack.malloc(indy.size * indy.dtype.itemsize) # cp.square
135+
mem_stack.free(indy.size * indy.dtype.itemsize) # cp.asarray(indy)
136+
137+
mem_stack.malloc(indx.size * indy.size * indx.dtype.itemsize) # cp.add.outer
138+
mem_stack.free(indx.size * indx.dtype.itemsize) # cp.square
139+
mem_stack.free(indy.size * indy.dtype.itemsize) # cp.square
140+
mem_stack.malloc(indx.size * indy.size * indx.dtype.itemsize) # phase_filter
141+
mem_stack.free(indx.size * indy.size * indx.dtype.itemsize) # cp.add.outer
142+
mem_stack.free(indx.size * indy.size * indx.dtype.itemsize) # phase_filter
143+
144+
else:
145+
# Build Lorentzian-type filter
146+
phase_filter = fftshift(
147+
1.0
148+
/ (
149+
1.0
150+
+ alpha
151+
* (
152+
cp.add.outer(
153+
cp.square(cp.asarray(indx)), cp.square(cp.asarray(indy))
154+
)
155+
)
156+
)
157+
)
110158

111-
phase_filter = phase_filter / phase_filter.max() # normalisation
159+
phase_filter = phase_filter / phase_filter.max() # normalisation
112160

113-
# Filter projections
114-
fft_tomo *= phase_filter
161+
# Filter projections
162+
fft_tomo *= phase_filter
163+
del phase_filter
115164

116165
# Apply filter and take inverse FFT
117-
ifft_filtered_tomo = ifft2(fft_tomo, axes=(-2, -1), overwrite_x=True).real
166+
ifft_input = (
167+
fft_tomo if not mem_stack else cp.empty(padded_tomo, dtype=cp.complex64)
168+
)
169+
ifft_plan = get_fft_plan(ifft_input, axes=(-2, -1))
170+
if mem_stack:
171+
mem_stack.malloc(ifft_plan.work_area.mem.size)
172+
mem_stack.free(ifft_plan.work_area.mem.size)
173+
else:
174+
with ifft_plan:
175+
ifft_filtered_tomo = ifft2(fft_tomo, axes=(-2, -1), overwrite_x=True).real
176+
del fft_tomo
177+
del ifft_plan
178+
del ifft_input
118179

119180
# slicing indices for cropping
120181
slc_indices = (
@@ -123,8 +184,19 @@ def paganin_filter(
123184
slice(pad_tup[2][0], pad_tup[2][0] + dx_orig, 1),
124185
)
125186

187+
if mem_stack:
188+
mem_stack.malloc(np.prod(tomo) * np.float32().itemsize) # astype(cp.float32)
189+
mem_stack.free(
190+
np.prod(padded_tomo) * np.complex64().itemsize
191+
) # ifft_filtered_tomo
192+
mem_stack.malloc(
193+
np.prod(tomo) * np.float32().itemsize
194+
) # return _log_kernel(tomo)
195+
return mem_stack.highwater
196+
126197
# crop the padded filtered data:
127198
tomo = ifft_filtered_tomo[slc_indices].astype(cp.float32)
199+
del ifft_filtered_tomo
128200

129201
# taking the negative log
130202
_log_kernel = cp.ElementwiseKernel(
@@ -177,7 +249,7 @@ def _calculate_pad_size(datashape: tuple) -> list:
177249

178250

179251
def _pad_projections_to_second_power(
    tomo: cp.ndarray, mem_stack: Optional[_DeviceMemStack] = None
) -> Tuple[cp.ndarray, Tuple[int, int]]:
    """
    Performs padding of each projection to the next power of 2.

    The function has two modes, selected by ``mem_stack``:

    * ``mem_stack is None`` (default): ``tomo`` is the data array and it is
      padded with ``cp.pad`` in "edge" mode.
    * ``mem_stack`` given: nothing is padded. ``tomo`` is treated as the
      *shape* of the data, the padded shape is computed from it, and the
      size of a padded array with 4-byte elements is registered on
      ``mem_stack``.

    Parameters
    ----------
    tomo : cp.ndarray
        Projection data, or its shape when ``mem_stack`` is given.
    mem_stack : Optional[_DeviceMemStack]
        Memory-estimation tracker. Defaults to None so that callers which
        only pass ``tomo`` keep working.

    Returns
    -------
    ndarray: padded projection data (or the padded shape as a list when
        ``mem_stack`` is given)
    tuple: a tuple with padding dimensions, one (before, after) entry per
        axis as returned by ``_calculate_pad_size``
    """
    estimate_only = mem_stack is not None
    full_shape_tomo = tomo if estimate_only else cp.shape(tomo)

    pad_list = _calculate_pad_size(full_shape_tomo)

    if estimate_only:
        # Shape after padding: original extent plus padding on both sides.
        padded_tomo = [
            sh + pad[0] + pad[1] for sh, pad in zip(full_shape_tomo, pad_list)
        ]
        mem_stack.malloc(np.prod(padded_tomo) * np.float32().itemsize)
    else:
        padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")

    return padded_tomo, tuple(pad_list)
204282

@@ -209,7 +287,7 @@ def _wavelength_micron(energy: float) -> float:
209287
return 2 * math.pi * PLANCK_CONSTANT * SPEED_OF_LIGHT / energy
210288

211289

212-
def _reciprocal_coord(pixel_size: float, num_grid: int) -> np.ndarray:
    """
    Calculate reciprocal grid coordinates for a given pixel size
    and discretization.

    The coordinates are the values ``-n, -n + 2, ..., n`` (with
    ``n = num_grid - 1``) scaled by ``2 * pi / (n * pixel_size)``. The array
    is built on the host with NumPy.

    Parameters
    ----------
    pixel_size : float
        Pixel size. NOTE(review): the expected unit is not visible from
        this function - confirm against the caller.
    num_grid : int
        Number of grid points.

    Returns
    -------
    np.ndarray
        Grid coordinates, as a 1D float32 array.
    """
    n = num_grid - 1
    # The result is a NumPy array, so take the dtype from NumPy rather than
    # reaching through the CuPy wrapper for it.
    rc = np.arange(-n, num_grid, 2, dtype=np.float32)
    if n == 0:
        # A single grid point has only the zero coordinate; scaling it would
        # divide by zero.
        return rc
    rc *= 2 * math.pi / (n * pixel_size)
    return rc
233311

@@ -238,6 +316,7 @@ def paganin_filter_savu_legacy(
238316
distance: float = 1.0,
239317
energy: float = 53.0,
240318
ratio_delta_beta: float = 250,
319+
calc_peak_gpu_mem: bool = False,
241320
) -> cp.ndarray:
242321
"""
243322
Perform single-material phase retrieval from flats/darks corrected tomographic measurements. For more detailed information, see :ref:`phase_contrast_module`.
@@ -256,11 +335,20 @@ def paganin_filter_savu_legacy(
256335
Beam energy in keV.
257336
ratio_delta_beta : float
258337
The ratio of delta/beta, where delta is the phase shift and real part of the complex material refractive index and beta is the absorption.
338+
calc_peak_gpu_mem: bool
339+
Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.
259340
260341
Returns
261342
-------
262343
cp.ndarray
263344
The 3D array of Paganin phase-filtered projection images.
264345
"""
265346

266-
return paganin_filter(tomo, pixel_size, distance, energy, ratio_delta_beta / 4)
347+
return paganin_filter(
348+
tomo,
349+
pixel_size,
350+
distance,
351+
energy,
352+
ratio_delta_beta / 4,
353+
calc_peak_gpu_mem=calc_peak_gpu_mem,
354+
)

httomolibgpu/recon/_phase_cross_correlation.py

Lines changed: 9 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@
3636
import cupyx.scipy.ndimage as ndi
3737
import numpy as np
3838

39-
def _upsampled_dft(
40-
data, upsampled_region_size, upsample_factor=1, axis_offsets=None
41-
):
39+
40+
def _upsampled_dft(data, upsampled_region_size, upsample_factor=1, axis_offsets=None):
4241
"""
4342
Upsampled DFT by matrix multiplication.
4443
@@ -148,9 +147,7 @@ def _compute_error(cross_correlation_max, src_amp, target_amp):
148147
)
149148

150149
with np.errstate(invalid="ignore"):
151-
error = 1.0 - cross_correlation_max * cross_correlation_max.conj() / (
152-
amp
153-
)
150+
error = 1.0 - cross_correlation_max * cross_correlation_max.conj() / (amp)
154151

155152
return cp.sqrt(cp.abs(error))
156153

@@ -192,9 +189,7 @@ def _disambiguate_shift(reference_image, moving_image, shift):
192189
negative_shift = [shift_i - s for shift_i, s in zip(positive_shift, shape)]
193190
subpixel = any(s % 1 != 0 for s in shift)
194191
interp_order = 3 if subpixel else 0
195-
shifted = ndi.shift(
196-
moving_image, shift, mode="grid-wrap", order=interp_order
197-
)
192+
shifted = ndi.shift(moving_image, shift, mode="grid-wrap", order=interp_order)
198193
indices = tuple(round(s) for s in positive_shift)
199194
splits_per_dim = [(slice(0, i), slice(i, None)) for i in indices]
200195
max_corr = -1.0
@@ -217,9 +212,7 @@ def _disambiguate_shift(reference_image, moving_image, shift):
217212
)
218213
return shift
219214
real_shift_acc = []
220-
for sl, pos_shift, neg_shift in zip(
221-
max_slice, positive_shift, negative_shift
222-
):
215+
for sl, pos_shift, neg_shift in zip(max_slice, positive_shift, negative_shift):
223216
real_shift_acc.append(pos_shift if sl.stop is None else neg_shift)
224217
if not subpixel:
225218
real_shift = tuple(map(int, real_shift_acc))
@@ -359,16 +352,12 @@ def phase_cross_correlation(
359352
# Initial shift estimate in upsampled grid
360353
# shift = cp.around(shift * upsample_factor) / upsample_factor
361354
upsample_factor = float(upsample_factor)
362-
shift = tuple(
363-
round(s * upsample_factor) / upsample_factor for s in shift
364-
)
355+
shift = tuple(round(s * upsample_factor) / upsample_factor for s in shift)
365356
upsampled_region_size = math.ceil(upsample_factor * 1.5)
366357
# Center of output array at dftshift + 1
367358
dftshift = float(upsampled_region_size // 2)
368359
# Matrix multiply DFT around the current shift estimate
369-
sample_region_offset = tuple(
370-
dftshift - s * upsample_factor for s in shift
371-
)
360+
sample_region_offset = tuple(dftshift - s * upsample_factor for s in shift)
372361
cross_correlation = _upsampled_dft(
373362
image_product.conj(),
374363
upsampled_region_size,
@@ -394,9 +383,7 @@ def phase_cross_correlation(
394383

395384
# If its only one row or column the shift along that dimension has no
396385
# effect. We set to zero.
397-
shift = tuple(
398-
s if axis_size != 1 else 0 for s, axis_size in zip(shift, shape)
399-
)
386+
shift = tuple(s if axis_size != 1 else 0 for s, axis_size in zip(shift, shape))
400387

401388
if disambiguate:
402389
if space.lower() != "real":
@@ -406,10 +393,7 @@ def phase_cross_correlation(
406393

407394
# Redirect user to masked_phase_cross_correlation if NaNs are observed
408395
if cp.isnan(CCmax) or cp.isnan(src_amp) or cp.isnan(target_amp):
409-
raise ValueError(
410-
"NaN values found, please remove NaNs from your "
411-
"input data"
412-
)
396+
raise ValueError("NaN values found, please remove NaNs from your " "input data")
413397

414398
return (
415399
shift,

tests/__init__.py

Whitespace-only changes.

tests/conftest.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,13 @@ def data_file(test_data_path):
6060
def ensure_clean_memory():
    """
    Release cached GPU resources before and after the code that uses this
    generator.

    Both times the default device memory pool and the default pinned memory
    pool are asked to free all their blocks, and the FFT plan cache is
    cleared. ``None`` is yielded in between.
    """

    def _release_cached_resources():
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
        cp.fft.config.get_plan_cache().clear()

    _release_cached_resources()
    yield None
    _release_cached_resources()
6670

6771

6872
@pytest.fixture
@@ -135,3 +139,20 @@ def host_detector_x(data_file):
135139
@pytest.fixture
136140
def detector_x(host_detector_x, ensure_clean_memory):
137141
return cp.asarray(host_detector_x)
142+
143+
144+
class MaxMemoryHook(cp.cuda.MemoryHook):
    """
    Memory hook that keeps a running total of allocated bytes and its peak.

    ``current`` is increased by ``mem_size`` after every malloc callback and
    decreased by ``mem_size`` after every free callback; ``max_mem`` holds
    the largest value ``current`` has reached.
    """

    def __init__(self, initial=0):
        # Both counters start from the same baseline.
        self.current = initial
        self.max_mem = initial

    def malloc_postprocess(
        self, device_id: int, size: int, mem_size: int, mem_ptr: int, pmem_id: int
    ):
        """Add ``mem_size`` to the running total and raise the peak if needed."""
        self.current += mem_size
        if self.current > self.max_mem:
            self.max_mem = self.current

    def free_postprocess(
        self, device_id: int, mem_size: int, mem_ptr: int, pmem_id: int
    ):
        """Subtract ``mem_size`` from the running total."""
        self.current -= mem_size

0 commit comments

Comments
 (0)