2222
2323import numpy as np
2424from httomolibgpu import cupywrapper
25+ from httomolibgpu .memory_estimator_helpers import _DeviceMemStack
2526
2627cp = cupywrapper .cp
2728cupy_run = cupywrapper .cupy_run
3031
if cupy_run:
    from cupyx.scipy.fft import fft2, ifft2, fftshift
    from cupyx.scipy.fftpack import get_fft_plan
else:
    # Placeholders so that the module can still be imported when CuPy is not
    # usable. Every GPU name imported in the branch above needs a counterpart
    # here, otherwise it is undefined at call time.
    fft2 = Mock()
    ifft2 = Mock()
    fftshift = Mock()
    get_fft_plan = Mock()
3739
3840from numpy import float32
39- from typing import Tuple
41+ from typing import Optional , Tuple
4042import math
4143
4244__all__ = [
@@ -54,6 +56,7 @@ def paganin_filter(
5456 distance : float = 1.0 ,
5557 energy : float = 53.0 ,
5658 ratio_delta_beta : float = 250 ,
59+ calc_peak_gpu_mem : bool = False ,
5760) -> cp .ndarray :
5861 """
5962 Perform single-material phase retrieval from flats/darks corrected tomographic measurements. For more detailed information, see :ref:`phase_contrast_module`.
@@ -71,30 +74,50 @@ def paganin_filter(
7174 Beam energy in keV.
7275 ratio_delta_beta : float
7376 The ratio of delta/beta, where delta is the phase shift and real part of the complex material refractive index and beta is the absorption.
77+ calc_peak_gpu_mem: bool
78+ Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.
7479
7580 Returns
7681 -------
7782 cp.ndarray
7883 The 3D array of Paganin phase-filtered projection images.
7984 """
85+ mem_stack = _DeviceMemStack () if calc_peak_gpu_mem else None
8086 # Check the input data is valid
81- if tomo .ndim != 3 :
87+ if not mem_stack and tomo .ndim != 3 :
8288 raise ValueError (
8389 f"Invalid number of dimensions in data: { tomo .ndim } ,"
8490 " please provide a stack of 2D projections."
8591 )
86-
87- dz_orig , dy_orig , dx_orig = tomo .shape
92+ if mem_stack :
93+ mem_stack .malloc (np .prod (tomo ) * np .float32 ().itemsize )
94+ dz_orig , dy_orig , dx_orig = tomo .shape if not mem_stack else tomo
8895
8996 # Perform padding to the power of 2 as FFT is O(n*log(n)) complexity
9097 # TODO: adding other options of padding?
91- padded_tomo , pad_tup = _pad_projections_to_second_power (tomo )
98+ padded_tomo , pad_tup = _pad_projections_to_second_power (tomo , mem_stack )
9299
93- dz , dy , dx = padded_tomo .shape
100+ dz , dy , dx = padded_tomo .shape if not mem_stack else padded_tomo
94101
95102 # 3D FFT of tomo data
96- padded_tomo = cp .asarray (padded_tomo , dtype = cp .complex64 )
97- fft_tomo = fft2 (padded_tomo , axes = (- 2 , - 1 ), overwrite_x = True )
103+ if mem_stack :
104+ mem_stack .malloc (np .prod (padded_tomo ) * np .complex64 ().itemsize )
105+ mem_stack .free (np .prod (padded_tomo ) * np .float32 ().itemsize )
106+ fft_input = cp .empty (padded_tomo , dtype = cp .complex64 )
107+ else :
108+ padded_tomo = cp .asarray (padded_tomo , dtype = cp .complex64 )
109+ fft_input = padded_tomo
110+
111+ fft_plan = get_fft_plan (fft_input , axes = (- 2 , - 1 ))
112+ if mem_stack :
113+ mem_stack .malloc (fft_plan .work_area .mem .size )
114+ mem_stack .free (fft_plan .work_area .mem .size )
115+ else :
116+ with fft_plan :
117+ fft_tomo = fft2 (padded_tomo , axes = (- 2 , - 1 ), overwrite_x = True )
118+ del padded_tomo
119+ del fft_input
120+ del fft_plan
98121
99122 # calculate alpha constant
100123 alpha = _calculate_alpha (energy , distance / 1e-6 , ratio_delta_beta )
@@ -103,18 +126,56 @@ def paganin_filter(
103126 indx = _reciprocal_coord (pixel_size , dy )
104127 indy = _reciprocal_coord (pixel_size , dx )
105128
106- # Build Lorentzian-type filter
107- phase_filter = fftshift (
108- 1.0 / (1.0 + alpha * (cp .add .outer (cp .square (indx ), cp .square (indy ))))
109- )
129+ if mem_stack :
130+ mem_stack .malloc (indx .size * indx .dtype .itemsize ) # cp.asarray(indx)
131+ mem_stack .malloc (indx .size * indx .dtype .itemsize ) # cp.square
132+ mem_stack .free (indx .size * indx .dtype .itemsize ) # cp.asarray(indx)
133+ mem_stack .malloc (indy .size * indy .dtype .itemsize ) # cp.asarray(indy)
134+ mem_stack .malloc (indy .size * indy .dtype .itemsize ) # cp.square
135+ mem_stack .free (indy .size * indy .dtype .itemsize ) # cp.asarray(indy)
136+
137+ mem_stack .malloc (indx .size * indy .size * indx .dtype .itemsize ) # cp.add.outer
138+ mem_stack .free (indx .size * indx .dtype .itemsize ) # cp.square
139+ mem_stack .free (indy .size * indy .dtype .itemsize ) # cp.square
140+ mem_stack .malloc (indx .size * indy .size * indx .dtype .itemsize ) # phase_filter
141+ mem_stack .free (indx .size * indy .size * indx .dtype .itemsize ) # cp.add.outer
142+ mem_stack .free (indx .size * indy .size * indx .dtype .itemsize ) # phase_filter
143+
144+ else :
145+ # Build Lorentzian-type filter
146+ phase_filter = fftshift (
147+ 1.0
148+ / (
149+ 1.0
150+ + alpha
151+ * (
152+ cp .add .outer (
153+ cp .square (cp .asarray (indx )), cp .square (cp .asarray (indy ))
154+ )
155+ )
156+ )
157+ )
110158
111- phase_filter = phase_filter / phase_filter .max () # normalisation
159+ phase_filter = phase_filter / phase_filter .max () # normalisation
112160
113- # Filter projections
114- fft_tomo *= phase_filter
161+ # Filter projections
162+ fft_tomo *= phase_filter
163+ del phase_filter
115164
116165 # Apply filter and take inverse FFT
117- ifft_filtered_tomo = ifft2 (fft_tomo , axes = (- 2 , - 1 ), overwrite_x = True ).real
166+ ifft_input = (
167+ fft_tomo if not mem_stack else cp .empty (padded_tomo , dtype = cp .complex64 )
168+ )
169+ ifft_plan = get_fft_plan (ifft_input , axes = (- 2 , - 1 ))
170+ if mem_stack :
171+ mem_stack .malloc (ifft_plan .work_area .mem .size )
172+ mem_stack .free (ifft_plan .work_area .mem .size )
173+ else :
174+ with ifft_plan :
175+ ifft_filtered_tomo = ifft2 (fft_tomo , axes = (- 2 , - 1 ), overwrite_x = True ).real
176+ del fft_tomo
177+ del ifft_plan
178+ del ifft_input
118179
119180 # slicing indices for cropping
120181 slc_indices = (
@@ -123,8 +184,19 @@ def paganin_filter(
123184 slice (pad_tup [2 ][0 ], pad_tup [2 ][0 ] + dx_orig , 1 ),
124185 )
125186
187+ if mem_stack :
188+ mem_stack .malloc (np .prod (tomo ) * np .float32 ().itemsize ) # astype(cp.float32)
189+ mem_stack .free (
190+ np .prod (padded_tomo ) * np .complex64 ().itemsize
191+ ) # ifft_filtered_tomo
192+ mem_stack .malloc (
193+ np .prod (tomo ) * np .float32 ().itemsize
194+ ) # return _log_kernel(tomo)
195+ return mem_stack .highwater
196+
126197 # crop the padded filtered data:
127198 tomo = ifft_filtered_tomo [slc_indices ].astype (cp .float32 )
199+ del ifft_filtered_tomo
128200
129201 # taking the negative log
130202 _log_kernel = cp .ElementwiseKernel (
@@ -177,7 +249,7 @@ def _calculate_pad_size(datashape: tuple) -> list:
177249
178250
def _pad_projections_to_second_power(
    tomo: cp.ndarray, mem_stack: Optional[_DeviceMemStack] = None
) -> Tuple[cp.ndarray, Tuple[int, int]]:
    """
    Performs padding of each projection to the next power of 2.

    The function has two modes, selected by ``mem_stack``:

    * ``mem_stack is None`` (default): ``tomo`` is the projection data and it
      is edge-padded with ``cp.pad``.
    * ``mem_stack`` given: memory-estimation mode. ``tomo`` is then the *shape*
      of the data rather than the data itself; nothing is padded, the padded
      shape is computed and the size of a float32 array of that shape is
      registered with ``mem_stack.malloc``.

    Parameters
    ----------
    tomo : cp.ndarray
        Projection data, or its shape (a sequence of ints) in
        memory-estimation mode.
    mem_stack : Optional[_DeviceMemStack]
        Memory tracker used in memory-estimation mode. Defaults to None, which
        keeps the original single-argument call working.

    Returns
    -------
    ndarray: padded 3d projection data (in memory-estimation mode: a list with
        the padded shape instead)
    tuple: a tuple with padding dimensions, one (before, after) pair per axis
        as produced by ``_calculate_pad_size``
    """
    # Compare against None explicitly so that the mode never depends on how the
    # tracker object happens to evaluate in a boolean context.
    estimating = mem_stack is not None

    full_shape_tomo = tomo if estimating else cp.shape(tomo)
    pad_list = _calculate_pad_size(full_shape_tomo)

    if estimating:
        # Only the resulting shape is needed: original extent plus the padding
        # added before and after along each axis.
        padded_tomo = [
            sh + pad[0] + pad[1] for sh, pad in zip(full_shape_tomo, pad_list)
        ]
        mem_stack.malloc(np.prod(padded_tomo) * np.float32().itemsize)
    else:
        padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")

    return padded_tomo, tuple(pad_list)
204282
@@ -209,7 +287,7 @@ def _wavelength_micron(energy: float) -> float:
209287 return 2 * math .pi * PLANCK_CONSTANT * SPEED_OF_LIGHT / energy
210288
211289
212- def _reciprocal_coord (pixel_size : float , num_grid : int ) -> cp .ndarray :
290+ def _reciprocal_coord (pixel_size : float , num_grid : int ) -> np .ndarray :
213291 """
214292 Calculate reciprocal grid coordinates for a given pixel size
215293 and discretization.
@@ -227,7 +305,7 @@ def _reciprocal_coord(pixel_size: float, num_grid: int) -> cp.ndarray:
227305 Grid coordinates.
228306 """
229307 n = num_grid - 1
230- rc = cp .arange (- n , num_grid , 2 , dtype = cp .float32 )
308+ rc = np .arange (- n , num_grid , 2 , dtype = cp .float32 )
231309 rc *= 2 * math .pi / (n * pixel_size )
232310 return rc
233311
@@ -238,6 +316,7 @@ def paganin_filter_savu_legacy(
238316 distance : float = 1.0 ,
239317 energy : float = 53.0 ,
240318 ratio_delta_beta : float = 250 ,
319+ calc_peak_gpu_mem : bool = False ,
241320) -> cp .ndarray :
242321 """
243322 Perform single-material phase retrieval from flats/darks corrected tomographic measurements. For more detailed information, see :ref:`phase_contrast_module`.
@@ -256,11 +335,20 @@ def paganin_filter_savu_legacy(
256335 Beam energy in keV.
257336 ratio_delta_beta : float
258337 The ratio of delta/beta, where delta is the phase shift and real part of the complex material refractive index and beta is the absorption.
338+ calc_peak_gpu_mem: bool
339+ Parameter to support memory estimation in HTTomo. Irrelevant to the method itself and can be ignored by user.
259340
260341 Returns
261342 -------
262343 cp.ndarray
263344 The 3D array of Paganin phase-filtered projection images.
264345 """
265346
266- return paganin_filter (tomo , pixel_size , distance , energy , ratio_delta_beta / 4 )
347+ return paganin_filter (
348+ tomo ,
349+ pixel_size ,
350+ distance ,
351+ energy ,
352+ ratio_delta_beta / 4 ,
353+ calc_peak_gpu_mem = calc_peak_gpu_mem ,
354+ )
0 commit comments