2323"""
2424
2525import numpy as np
26- from httomolibgpu import cupywrapper
2726from typing import Union
2827
28+ from httomolibgpu import cupywrapper
29+
2930cp = cupywrapper .cp
3031nvtx = cupywrapper .nvtx
3132
3233from numpy import float32
34+ from httomolibgpu .cuda_kernels import load_cuda_module
3335
3436__all__ = [
3537 "median_filter" ,
4042def median_filter (
4143 data : cp .ndarray ,
4244 kernel_size : int = 3 ,
43- axis : Union [int , None ] = 0 ,
4445 dif : float = 0.0 ,
4546) -> cp .ndarray :
4647 """
47- Apply 2D or 3D median filter to a 3D CuPy array. For more detailed information, see :ref:`method_median_filter`.
48+ Apply 3D median filter to a 3D CuPy array. For more detailed information, see :ref:`method_median_filter`.
4849
4950 Parameters
5051 ----------
5152 data : cp.ndarray
5253 Input CuPy 3D array either float32 or uint16 data type.
5354 kernel_size : int, optional
5455 The size of the filter's kernel (a diameter).
55- axis: int or None, optional:
56- Axis along which the 2D filter kernel should be applied. If set to None, then the kernel is 3D.
5756 dif : float, optional
5857 Expected difference value between outlier value and the
5958 median value of the array, leave equal to 0 for classical median.
@@ -69,7 +68,7 @@ def median_filter(
6968 If the input array is not three dimensional.
7069 """
7170 if cupywrapper .cupy_run :
72- return __median_filter (data , kernel_size , axis , dif )
71+ return __median_filter (data , kernel_size , dif )
7372 else :
7473 print ("median_filter won't be executed because CuPy is not installed" )
7574 return data
@@ -79,18 +78,8 @@ def median_filter(
7978def __median_filter (
8079 data : cp .ndarray ,
8180 kernel_size : int = 3 ,
82- axis : Union [int , None ] = 0 ,
8381 dif : float = 0.0 ,
8482) -> cp .ndarray :
85- try :
86- from cucim .skimage .filters import median
87- from cucim .skimage .morphology import disk
88- except ImportError :
89- print (
90- "Cucim library of Rapidsai is a required dependency for median_filter and remove_outlier modules, please install"
91- )
92- from httomolibgpu .cuda_kernels import load_cuda_module
93-
9483 input_type = data .dtype
9584
9685 if input_type not in ["float32" , "uint16" ]:
@@ -105,65 +94,32 @@ def __median_filter(
10594 if kernel_size not in [3 , 5 , 7 , 9 , 11 , 13 ]:
10695 raise ValueError ("Please select a correct kernel size: 3, 5, 7, 9, 11, 13" )
10796
108- if axis not in [0 , 1 , 2 , None ]:
109- raise ValueError ("The axis should be 0,1,2 or None for full 3d processing" )
110-
11197 dz , dy , dx = data .shape
11298 output = cp .copy (data , order = "C" )
11399
114- if axis == 0 :
115- for j in range (dz ):
116- median (data [j , :, :], footprint = disk (kernel_size // 2 ), out = output [j , :, :])
117- elif axis == 1 :
118- for j in range (dy ):
119- median (data [:, j , :], footprint = disk (kernel_size // 2 ), out = output [:, j , :])
120- elif axis == 2 :
121- for j in range (dx ):
122- median (data [:, :, j ], footprint = disk (kernel_size // 2 ), out = output [:, :, j ])
123- else :
124- # 3d median or dezinger
125- kernel_args = "median_general_kernel3d<{0}, {1}>" .format (
126- "float" if input_type == "float32" else "unsigned short" , kernel_size
127- )
128- block_x = 128
129- # setting grid/block parameters
130- block_dims = (block_x , 1 , 1 )
131- grid_x = (dx + block_x - 1 ) // block_x
132- grid_y = dy
133- grid_z = dz
134- grid_dims = (grid_x , grid_y , grid_z )
135- params = (data , output , cp .float32 (dif ), dz , dy , dx )
136-
137- median_module = load_cuda_module (
138- "median_kernel" , name_expressions = [kernel_args ]
139- )
140- median_filt = median_module .get_function (kernel_args )
141-
142- median_filt (grid_dims , block_dims , params )
143-
144- if axis is not None and dif > 0 :
145- # 2d dezingering enabled
146- kernel_name = "thresholding"
147- kernel = r"""
148- float dif_curr = abs(float(data) - float(output));
149- if (dif_curr > dif) {
150- output = data;
151- }
152- """
153- thresholding_kernel = cp .ElementwiseKernel (
154- "T data, raw float32 dif" ,
155- "T output" ,
156- kernel ,
157- kernel_name ,
158- options = ("-std=c++11" ,),
159- no_return = True ,
160- )
161- thresholding_kernel (data , float32 (dif ), output )
100+ # 3d median or dezinger
101+ kernel_args = "median_general_kernel3d<{0}, {1}>" .format (
102+ "float" if input_type == "float32" else "unsigned short" , kernel_size
103+ )
104+ block_x = 128
105+ # setting grid/block parameters
106+ block_dims = (block_x , 1 , 1 )
107+ grid_x = (dx + block_x - 1 ) // block_x
108+ grid_y = dy
109+ grid_z = dz
110+ grid_dims = (grid_x , grid_y , grid_z )
111+ params = (data , output , cp .float32 (dif ), dz , dy , dx )
112+
113+ median_module = load_cuda_module ("median_kernel" , name_expressions = [kernel_args ])
114+ median_filt = median_module .get_function (kernel_args )
115+
116+ median_filt (grid_dims , block_dims , params )
117+
162118 return output
163119
164120
165121def remove_outlier (
166- data : cp .ndarray , kernel_size : int = 3 , axis : Union [ int , None ] = 0 , dif : float = 0.1
122+ data : cp .ndarray , kernel_size : int = 3 , dif : float = 0.1
167123) -> cp .ndarray :
168124 """Selectively applies 3D median filter to a 3D CuPy array to remove outliers. Also called a dezinger.
169125 For more detailed information, see :ref:`method_outlier_removal`.
@@ -174,8 +130,6 @@ def remove_outlier(
174130 Input CuPy 3D array either float32 or uint16 data type.
175131 kernel_size : int, optional
176132 The size of the filter's kernel (a diameter).
177- axis: int or None, optional:
178- Axis along which the 2D filter kernel should be applied. If set to None, then the kernel is 3D.
179133 dif : float, optional
180134 Expected difference value between outlier value and the
181135 median value of the array.
@@ -195,7 +149,7 @@ def remove_outlier(
195149 raise ValueError ("Threshold value (dif) must be positive and nonzero." )
196150
197151 if cupywrapper .cupy_run :
198- return __median_filter (data , kernel_size , axis , dif )
152+ return __median_filter (data , kernel_size , dif )
199153 else :
200154 print ("remove_outlier won't be executed because CuPy is not installed" )
201155 return data
0 commit comments