4949]
5050
5151
52- def raven_filter_savu (
52+ def raven_filter (
5353 data : cp .ndarray ,
54- kernel_size : int = 3 ,
55- pad_y : int = 100 ,
56- pad_x : int = 100 ,
54+ pad_y : int = 20 ,
55+ pad_x : int = 20 ,
5756 pad_method : str = "edge" ,
57+ uvalue : int = 20 ,
58+ nvalue : int = 4 ,
59+ vvalue : int = 2 ,
5860) -> cp .ndarray :
5961 """
6062 Applies raven filter to a 3D CuPy array. For more detailed information, see :ref:`method_raven_filter`.
@@ -63,8 +65,6 @@ def raven_filter_savu(
6365 ----------
6466 data : cp.ndarray
6567 Input CuPy 3D array either float32 or uint16 data type.
66- kernel_size : int, optional
67- The size of the filter's kernel (a diameter).
6868
6969 pad_y : int, optional
7070 Pad the top and bottom of projections.
@@ -75,10 +75,19 @@ def raven_filter_savu(
7575 pad_method : str, optional
7676 Numpy pad method to use.
7777
78+ uvalue : int, optional
79+ The shape of filter.
80+
81+ nvalue : int, optional
82+ The shape of filter.
83+
84+ vvalue : int, optional
85+ The number of rows to be applied the filter
86+
7887 Returns
7988 -------
8089 ndarray
81- Raven filtered 3D CuPy array either float32 or uint16 data type.
90+ Raven filtered 3D CuPy array in float32 data type.
8291
8392 Raises
8493 ------
@@ -87,55 +96,40 @@ def raven_filter_savu(
8796 """
8897 input_type = data .dtype
8998
90- if input_type not in ["float32" , "uint16" ]:
91- raise ValueError ("The input data should be either float32 or uint16 data type" )
92-
93- if data .ndim == 3 :
94- if 0 in data .shape :
95- raise ValueError ("The length of one of dimensions is equal to zero" )
96- else :
97- raise ValueError ("The input array must be a 3D array" )
98-
99- if kernel_size not in [3 , 5 , 7 , 9 , 11 , 13 ]:
100- raise ValueError ("Please select a correct kernel size: 3, 5, 7, 9, 11, 13" )
101-
99+ if input_type not in ["float32" ]:
100+ raise ValueError ("The input data should be float32" )
102101
103- dz_orig , dy_orig , dx_orig = data .shape
102+ # if data.ndim != 3:
103+ # raise ValueError("only 3D data is supported")
104104
105- padded_data , pad_tup = cp . pad ( data , fftpad , "edge" )
106- dz , dy , dx = padded_data . shape
105+ # Padding of the data
106+ padded_data = cp . pad ( data , (( pad_y , pad_y ), ( pad_x , pad_x )), mode = pad_method )
107107
108- # 3D FFT of data
109- padded_data = cp .pad (data , ((0 , 0 ), (pad_y , pad_y ), (pad_x , pad_x )), mode = pad_method )
108+ # FFT and shift of data
110109 fft_data = fft2 (padded_data , axes = (- 2 , - 1 ), overwrite_x = True )
111110 fft_data_shifted = fftshift (fft_data )
112111
113112 # Setup various values for the filter
114- _ , height , width = data .shape
113+ height , width = data .shape
115114
116115 height1 = height + 2 * pad_y
117116 width1 = width + 2 * pad_x
118117
119- # raven
120- kernel_args = "raven_general_kernel3d<{0}, {1}>" .format (
121- "float" if input_type == "float32" else "unsigned short" , kernel_size
122- )
123- block_x = 128
124118 # setting grid/block parameters
119+ block_x = 128
125120 block_dims = (block_x , 1 , 1 )
126- grid_x = (dx + block_x - 1 ) // block_x
127- grid_y = dy
128- grid_z = dz
129- grid_dims = (grid_x , grid_y , grid_z )
130- params = (fft_data_shifted , dz , dy , dx )
131-
132- raven_module = load_cuda_module ("raven_kernel" , name_expressions = [kernel_args ])
133- raven_filt = raven_module .get_function (kernel_args )
134-
121+ grid_x = (width1 + block_x - 1 ) // block_x
122+ grid_y = height1
123+ grid_dims = (grid_x , grid_y , 1 )
124+ params = (fft_data_shifted , fft_data , width1 , height1 , uvalue , nvalue , vvalue )
125+
126+ raven_module = load_cuda_module ("raven_filter" )
127+ raven_filt = raven_module .get_function ("raven_filter" )
128+
135129 raven_filt (grid_dims , block_dims , params )
136-
137- fft_data = fftshift ( fft_data_shifted )
138-
139- data = ifft2 (fft_data , axes = (- 2 , - 1 ), overwrite_x = True , norm = "forward" )
130+
131+ # raven_fil already doing ifftshifting
132+ # fft_data = ifftshift(fft_data_shifted)
133+ data = ifft2 (fft_data , axes = (- 2 , - 1 ), overwrite_x = True )
140134
141135 return data
0 commit comments