3535
3636if cupy_run :
3737 from httomolibgpu .cuda_kernels import load_cuda_module
38+ from cupyx .scipy .fft import fft2 , ifft2 , fftshift , ifftshift
3839else :
3940 load_cuda_module = Mock ()
41+ fft2 = Mock ()
42+ ifft2 = Mock ()
43+ fftshift = Mock ()
44+ ifftshift = Mock ()
4045
4146
4247__all__ = [
4348 "raven_filter" ,
4449]
4550
4651
47- def raven_filter (
52+ def raven_filter_savu (
4853 data : cp .ndarray ,
4954 kernel_size : int = 3 ,
50- dif : float = 0.0 ,
55+ pad_y : int = 100 ,
56+ pad_x : int = 100 ,
57+ pad_method : str = "edge" ,
5158) -> cp .ndarray :
5259 """
5360 Applies raven filter to a 3D CuPy array. For more detailed information, see :ref:`method_raven_filter`.
@@ -58,14 +65,20 @@ def raven_filter(
5865 Input CuPy 3D array either float32 or uint16 data type.
5966 kernel_size : int, optional
6067 The size of the filter's kernel (a diameter).
61- dif : float, optional
62- Expected difference value between outlier value and the
63- median value of the array, leave equal to 0 for classical median.
68+
69+ pad_y : int, optional
70+ Pad the top and bottom of projections.
71+
72+ pad_x : int, optional
73+ Pad the left and right of projections.
74+
75+ pad_method : str, optional
76+ Numpy pad method to use.
6477
6578 Returns
6679 -------
6780 ndarray
68- Median filtered 3D CuPy array either float32 or uint16 data type.
81+ Raven filtered 3D CuPy array either float32 or uint16 data type.
6982
7083 Raises
7184 ------
@@ -86,11 +99,25 @@ def raven_filter(
8699 if kernel_size not in [3 , 5 , 7 , 9 , 11 , 13 ]:
87100 raise ValueError ("Please select a correct kernel size: 3, 5, 7, 9, 11, 13" )
88101
89- dz , dy , dx = data .shape
90- output = cp .copy (data , order = "C" )
91102
92- # 3d median or dezinger
93- kernel_args = "median_general_kernel3d<{0}, {1}>" .format (
103+ dz_orig , dy_orig , dx_orig = data .shape
104+
105+ padded_data , pad_tup = cp .pad (data , fftpad , "edge" )
106+ dz , dy , dx = padded_data .shape
107+
108+ # 3D FFT of data
109+ padded_data = cp .pad (data , ((0 , 0 ), (pad_y , pad_y ), (pad_x , pad_x )), mode = pad_method )
110+ fft_data = fft2 (padded_data , axes = (- 2 , - 1 ), overwrite_x = True )
111+ fft_data_shifted = fftshift (fft_data )
112+
113+ # Setup various values for the filter
114+ _ , height , width = data .shape
115+
116+ height1 = height + 2 * pad_y
117+ width1 = width + 2 * pad_x
118+
119+ # raven
120+ kernel_args = "raven_general_kernel3d<{0}, {1}>" .format (
94121 "float" if input_type == "float32" else "unsigned short" , kernel_size
95122 )
96123 block_x = 128
@@ -100,11 +127,15 @@ def raven_filter(
100127 grid_y = dy
101128 grid_z = dz
102129 grid_dims = (grid_x , grid_y , grid_z )
103- params = (data , output , cp .float32 (dif ), dz , dy , dx )
130+ params = (fft_data_shifted , dz , dy , dx )
131+
132+ raven_module = load_cuda_module ("raven_kernel" , name_expressions = [kernel_args ])
133+ raven_filt = raven_module .get_function (kernel_args )
134+
135+ raven_filt (grid_dims , block_dims , params )
104136
105- median_module = load_cuda_module ("raven_filter" , name_expressions = [kernel_args ])
106- median_filt = median_module .get_function (kernel_args )
137+ fft_data = fftshift (fft_data_shifted )
107138
108- median_filt ( grid_dims , block_dims , params )
139+ data = ifft2 ( fft_data , axes = ( - 2 , - 1 ), overwrite_x = True , norm = "forward" )
109140
110- return output
141+ return data
0 commit comments