2222
2323import numpy as np
2424from httomolibgpu import cupywrapper
25+ from httomolibgpu .memory_estimator_helpers import _DeviceMemStack
2526
2627cp = cupywrapper .cp
2728cupy_run = cupywrapper .cupy_run
3031
if cupy_run:
    from cupyx.scipy.fft import fft2, ifft2, fftshift
    from cupyx.scipy.fftpack import get_fft_plan
else:
    # cupy_run is falsy: bind every name imported in the branch above to a
    # Mock placeholder so that later references to them still resolve.
    fft2 = Mock()
    ifft2 = Mock()
    fftshift = Mock()
    # get_fft_plan is used by the filter code as well; without a placeholder
    # it would be an undefined name in this branch.
    get_fft_plan = Mock()
3739
3840from numpy import float32
39- from typing import Tuple
41+ from typing import Optional , Tuple
4042import math
4143
4244__all__ = [
@@ -54,6 +56,7 @@ def paganin_filter(
5456 distance : float = 1.0 ,
5557 energy : float = 53.0 ,
5658 ratio_delta_beta : float = 250 ,
59+ calc_peak_gpu_mem : bool = False ,
5760) -> cp .ndarray :
5861 """
5962 Perform single-material phase retrieval from flats/darks corrected tomographic measurements. For more detailed information, see :ref:`phase_contrast_module`.
@@ -77,24 +80,42 @@ def paganin_filter(
7780 cp.ndarray
7881 The 3D array of Paganin phase-filtered projection images.
7982 """
83+ mem_stack = _DeviceMemStack () if calc_peak_gpu_mem else None
8084 # Check the input data is valid
81- if tomo .ndim != 3 :
85+ if not mem_stack and tomo .ndim != 3 :
8286 raise ValueError (
8387 f"Invalid number of dimensions in data: { tomo .ndim } ,"
8488 " please provide a stack of 2D projections."
8589 )
86-
87- dz_orig , dy_orig , dx_orig = tomo .shape
90+ if mem_stack :
91+ mem_stack .malloc (np .prod (tomo ) * np .float32 ().itemsize )
92+ dz_orig , dy_orig , dx_orig = tomo .shape if not mem_stack else tomo
8893
8994 # Perform padding to the power of 2 as FFT is O(n*log(n)) complexity
9095 # TODO: adding other options of padding?
91- padded_tomo , pad_tup = _pad_projections_to_second_power (tomo )
96+ padded_tomo , pad_tup = _pad_projections_to_second_power (tomo , mem_stack )
9297
93- dz , dy , dx = padded_tomo .shape
98+ dz , dy , dx = padded_tomo .shape if not mem_stack else padded_tomo
9499
95100 # 3D FFT of tomo data
96- padded_tomo = cp .asarray (padded_tomo , dtype = cp .complex64 )
97- fft_tomo = fft2 (padded_tomo , axes = (- 2 , - 1 ), overwrite_x = True )
101+ if mem_stack :
102+ mem_stack .malloc (np .prod (padded_tomo ) * np .complex64 ().itemsize )
103+ mem_stack .free (np .prod (padded_tomo ) * np .float32 ().itemsize )
104+ fft_input = cp .empty (padded_tomo , dtype = cp .complex64 )
105+ else :
106+ padded_tomo = cp .asarray (padded_tomo , dtype = cp .complex64 )
107+ fft_input = padded_tomo
108+
109+ fft_plan = get_fft_plan (fft_input , axes = (- 2 , - 1 ))
110+ if mem_stack :
111+ mem_stack .malloc (fft_plan .work_area .mem .size )
112+ mem_stack .free (fft_plan .work_area .mem .size )
113+ else :
114+ with fft_plan :
115+ fft_tomo = fft2 (padded_tomo , axes = (- 2 , - 1 ), overwrite_x = True )
116+ del padded_tomo
117+ del fft_input
118+ del fft_plan
98119
99120 # calculate alpha constant
100121 alpha = _calculate_alpha (energy , distance / 1e-6 , ratio_delta_beta )
@@ -103,18 +124,41 @@ def paganin_filter(
103124 indx = _reciprocal_coord (pixel_size , dy )
104125 indy = _reciprocal_coord (pixel_size , dx )
105126
106- # Build Lorentzian-type filter
107- phase_filter = fftshift (
108- 1.0 / (1.0 + alpha * (cp .add .outer (cp .square (indx ), cp .square (indy ))))
109- )
127+ if mem_stack :
128+ mem_stack .malloc (indx .size * indx .dtype .itemsize ) # cp.asarray(indx)
129+ mem_stack .malloc (indx .size ** 2 * indx .dtype .itemsize ) # cp.square
130+ mem_stack .malloc (indy .size * indy .dtype .itemsize ) # cp.asarray(indy)
131+ mem_stack .malloc (indy .size ** 2 * indy .dtype .itemsize ) # cp.square
132+
133+ mem_stack .malloc (indx .size ** 2 * indx .dtype .itemsize ) # phase_filter
134+
135+ mem_stack .free (indx .size * indx .dtype .itemsize )
136+ mem_stack .free (indx .size ** 2 * indx .dtype .itemsize )
137+ mem_stack .free (indy .size * indy .dtype .itemsize )
138+ mem_stack .free (indy .size ** 2 * indy .dtype .itemsize )
139+ else :
140+ # Build Lorentzian-type filter
141+ phase_filter = fftshift (
142+ 1.0 / (1.0 + alpha * (cp .add .outer (cp .square (cp .asarray (indx )), cp .square (cp .asarray (indy )))))
143+ )
110144
111- phase_filter = phase_filter / phase_filter .max () # normalisation
145+ phase_filter = phase_filter / phase_filter .max () # normalisation
112146
113- # Filter projections
114- fft_tomo *= phase_filter
147+ # Filter projections
148+ fft_tomo *= phase_filter
115149
116150 # Apply filter and take inverse FFT
117- ifft_filtered_tomo = ifft2 (fft_tomo , axes = (- 2 , - 1 ), overwrite_x = True ).real
151+ ifft_input = fft_tomo if not mem_stack else cp .empty (padded_tomo , dtype = cp .complex64 )
152+ ifft_plan = get_fft_plan (ifft_input , axes = (- 2 , - 1 ))
153+ if mem_stack :
154+ mem_stack .malloc (ifft_plan .work_area .mem .size )
155+ mem_stack .free (ifft_plan .work_area .mem .size )
156+ else :
157+ with ifft_plan :
158+ ifft_filtered_tomo = ifft2 (fft_tomo , axes = (- 2 , - 1 ), overwrite_x = True ).real
159+ del fft_tomo
160+ del ifft_plan
161+ del ifft_input
118162
119163 # slicing indices for cropping
120164 slc_indices = (
@@ -123,8 +167,15 @@ def paganin_filter(
123167 slice (pad_tup [2 ][0 ], pad_tup [2 ][0 ] + dx_orig , 1 ),
124168 )
125169
170+ if mem_stack :
171+ mem_stack .malloc (np .prod (tomo ) * np .float32 ().itemsize ) # astype(cp.float32)
172+ mem_stack .free (np .prod (padded_tomo ) * np .complex64 ().itemsize ) # ifft_filtered_tomo
173+ mem_stack .malloc (np .prod (tomo ) * np .float32 ().itemsize ) # return _log_kernel(tomo)
174+ return mem_stack .highwater
175+
126176 # crop the padded filtered data:
127177 tomo = ifft_filtered_tomo [slc_indices ].astype (cp .float32 )
178+ del ifft_filtered_tomo
128179
129180 # taking the negative log
130181 _log_kernel = cp .ElementwiseKernel (
@@ -178,6 +229,7 @@ def _calculate_pad_size(datashape: tuple) -> list:
178229
def _pad_projections_to_second_power(
    tomo: cp.ndarray,
    mem_stack: Optional[_DeviceMemStack] = None,
) -> Tuple[cp.ndarray, Tuple[int, int]]:
    """
    Performs padding of each projection to the next power of 2.

    The function has two modes, selected by ``mem_stack``:

    * ``mem_stack is None`` (default): ``tomo`` is the data array and it is
      padded with ``cp.pad`` in "edge" mode.
    * ``mem_stack`` given: no padding is performed. ``tomo`` is treated as
      the *shape* of the data, the padded shape is computed from it, and a
      float32 buffer of that padded shape is registered on ``mem_stack``
      via ``malloc``.

    Parameters
    ----------
    tomo : cp.ndarray
        3d projection data, or its shape (a sequence of ints) when
        ``mem_stack`` is given.
    mem_stack : _DeviceMemStack, optional
        Memory-tracking object. When provided, the function only records
        the size of the padded array instead of creating it.

    Returns
    -------
    ndarray: padded 3d projection data (the padded shape as a list of ints
        when ``mem_stack`` is given)
    tuple: a tuple with padding dimensions, one ``(before, after)`` entry
        per axis as produced by ``_calculate_pad_size``
    """
    # Explicit None check: the mode must not depend on how the stack object
    # happens to evaluate in a boolean context.
    estimate_only = mem_stack is not None

    full_shape_tomo = tomo if estimate_only else cp.shape(tomo)
    pad_list = _calculate_pad_size(full_shape_tomo)

    if estimate_only:
        # Shape after padding: original extent plus both pads on each axis.
        padded_tomo = [
            sh + pad[0] + pad[1] for sh, pad in zip(full_shape_tomo, pad_list)
        ]
        mem_stack.malloc(np.prod(padded_tomo) * np.float32().itemsize)
    else:
        padded_tomo = cp.pad(tomo, tuple(pad_list), "edge")

    return padded_tomo, tuple(pad_list)
204260
@@ -209,7 +265,7 @@ def _wavelength_micron(energy: float) -> float:
209265 return 2 * math .pi * PLANCK_CONSTANT * SPEED_OF_LIGHT / energy
210266
211267
def _reciprocal_coord(pixel_size: float, num_grid: int) -> np.ndarray:
    """
    Calculate reciprocal grid coordinates for a given pixel size
    and discretization.

    Parameters
    ----------
    pixel_size : float
        Detector pixel size.
    num_grid : int
        Size of the reciprocal grid.

    Returns
    -------
    ndarray
        Grid coordinates as a host (NumPy) float32 array: the values
        ``-n, -n + 2, ..., n`` with ``n = num_grid - 1``, scaled by
        ``2 * pi / (n * pixel_size)``. For ``num_grid == 1`` the single
        coordinate is 0.
    """
    n = num_grid - 1
    # The result is a NumPy array, so take the dtype from NumPy as well
    # rather than going through the CuPy wrapper.
    rc = np.arange(-n, num_grid, 2, dtype=np.float32)
    if n == 0:
        # A one-point grid holds only the zero coordinate; applying the
        # scale factor would divide by zero.
        return rc
    rc *= 2 * math.pi / (n * pixel_size)
    return rc
233289
0 commit comments