Skip to content

Commit 51e05ba

Browse files
committed
removes 2d median/dezingering using cucim
1 parent 52deaaa commit 51e05ba

File tree

2 files changed

+32
-120
lines changed

2 files changed

+32
-120
lines changed

httomolibgpu/misc/corr.py

Lines changed: 25 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323
"""
2424

2525
import numpy as np
26-
from httomolibgpu import cupywrapper
2726
from typing import Union
2827

28+
from httomolibgpu import cupywrapper
29+
2930
cp = cupywrapper.cp
3031
nvtx = cupywrapper.nvtx
3132

3233
from numpy import float32
34+
from httomolibgpu.cuda_kernels import load_cuda_module
3335

3436
__all__ = [
3537
"median_filter",
@@ -40,20 +42,17 @@
4042
def median_filter(
4143
data: cp.ndarray,
4244
kernel_size: int = 3,
43-
axis: Union[int, None] = 0,
4445
dif: float = 0.0,
4546
) -> cp.ndarray:
4647
"""
47-
Apply 2D or 3D median filter to a 3D CuPy array. For more detailed information, see :ref:`method_median_filter`.
48+
Apply 3D median filter to a 3D CuPy array. For more detailed information, see :ref:`method_median_filter`.
4849
4950
Parameters
5051
----------
5152
data : cp.ndarray
5253
Input CuPy 3D array either float32 or uint16 data type.
5354
kernel_size : int, optional
5455
The size of the filter's kernel (a diameter).
55-
axis: int or None, optional:
56-
Axis along which the 2D filter kernel should be applied. If set to None, then the kernel is 3D.
5756
dif : float, optional
5857
Expected difference value between outlier value and the
5958
median value of the array, leave equal to 0 for classical median.
@@ -69,7 +68,7 @@ def median_filter(
6968
If the input array is not three dimensional.
7069
"""
7170
if cupywrapper.cupy_run:
72-
return __median_filter(data, kernel_size, axis, dif)
71+
return __median_filter(data, kernel_size, dif)
7372
else:
7473
print("median_filter won't be executed because CuPy is not installed")
7574
return data
@@ -79,18 +78,8 @@ def median_filter(
7978
def __median_filter(
8079
data: cp.ndarray,
8180
kernel_size: int = 3,
82-
axis: Union[int, None] = 0,
8381
dif: float = 0.0,
8482
) -> cp.ndarray:
85-
try:
86-
from cucim.skimage.filters import median
87-
from cucim.skimage.morphology import disk
88-
except ImportError:
89-
print(
90-
"Cucim library of Rapidsai is a required dependency for median_filter and remove_outlier modules, please install"
91-
)
92-
from httomolibgpu.cuda_kernels import load_cuda_module
93-
9483
input_type = data.dtype
9584

9685
if input_type not in ["float32", "uint16"]:
@@ -105,65 +94,32 @@ def __median_filter(
10594
if kernel_size not in [3, 5, 7, 9, 11, 13]:
10695
raise ValueError("Please select a correct kernel size: 3, 5, 7, 9, 11, 13")
10796

108-
if axis not in [0, 1, 2, None]:
109-
raise ValueError("The axis should be 0,1,2 or None for full 3d processing")
110-
11197
dz, dy, dx = data.shape
11298
output = cp.copy(data, order="C")
11399

114-
if axis == 0:
115-
for j in range(dz):
116-
median(data[j, :, :], footprint=disk(kernel_size // 2), out=output[j, :, :])
117-
elif axis == 1:
118-
for j in range(dy):
119-
median(data[:, j, :], footprint=disk(kernel_size // 2), out=output[:, j, :])
120-
elif axis == 2:
121-
for j in range(dx):
122-
median(data[:, :, j], footprint=disk(kernel_size // 2), out=output[:, :, j])
123-
else:
124-
# 3d median or dezinger
125-
kernel_args = "median_general_kernel3d<{0}, {1}>".format(
126-
"float" if input_type == "float32" else "unsigned short", kernel_size
127-
)
128-
block_x = 128
129-
# setting grid/block parameters
130-
block_dims = (block_x, 1, 1)
131-
grid_x = (dx + block_x - 1) // block_x
132-
grid_y = dy
133-
grid_z = dz
134-
grid_dims = (grid_x, grid_y, grid_z)
135-
params = (data, output, cp.float32(dif), dz, dy, dx)
136-
137-
median_module = load_cuda_module(
138-
"median_kernel", name_expressions=[kernel_args]
139-
)
140-
median_filt = median_module.get_function(kernel_args)
141-
142-
median_filt(grid_dims, block_dims, params)
143-
144-
if axis is not None and dif > 0:
145-
# 2d dezingering enabled
146-
kernel_name = "thresholding"
147-
kernel = r"""
148-
float dif_curr = abs(float(data) - float(output));
149-
if (dif_curr > dif) {
150-
output = data;
151-
}
152-
"""
153-
thresholding_kernel = cp.ElementwiseKernel(
154-
"T data, raw float32 dif",
155-
"T output",
156-
kernel,
157-
kernel_name,
158-
options=("-std=c++11",),
159-
no_return=True,
160-
)
161-
thresholding_kernel(data, float32(dif), output)
100+
# 3d median or dezinger
101+
kernel_args = "median_general_kernel3d<{0}, {1}>".format(
102+
"float" if input_type == "float32" else "unsigned short", kernel_size
103+
)
104+
block_x = 128
105+
# setting grid/block parameters
106+
block_dims = (block_x, 1, 1)
107+
grid_x = (dx + block_x - 1) // block_x
108+
grid_y = dy
109+
grid_z = dz
110+
grid_dims = (grid_x, grid_y, grid_z)
111+
params = (data, output, cp.float32(dif), dz, dy, dx)
112+
113+
median_module = load_cuda_module("median_kernel", name_expressions=[kernel_args])
114+
median_filt = median_module.get_function(kernel_args)
115+
116+
median_filt(grid_dims, block_dims, params)
117+
162118
return output
163119

164120

165121
def remove_outlier(
166-
data: cp.ndarray, kernel_size: int = 3, axis: Union[int, None] = 0, dif: float = 0.1
122+
data: cp.ndarray, kernel_size: int = 3, dif: float = 0.1
167123
) -> cp.ndarray:
168124
"""Selectively applies 3D median filter to a 3D CuPy array to remove outliers. Also called a dezinger.
169125
For more detailed information, see :ref:`method_outlier_removal`.
@@ -174,8 +130,6 @@ def remove_outlier(
174130
Input CuPy 3D array either float32 or uint16 data type.
175131
kernel_size : int, optional
176132
The size of the filter's kernel (a diameter).
177-
axis: int or None, optional:
178-
Axis along which the 2D filter kernel should be applied. If set to None, then the kernel is 3D.
179133
dif : float, optional
180134
Expected difference value between outlier value and the
181135
median value of the array.
@@ -195,7 +149,7 @@ def remove_outlier(
195149
raise ValueError("Threshold value (dif) must be positive and nonzero.")
196150

197151
if cupywrapper.cupy_run:
198-
return __median_filter(data, kernel_size, axis, dif)
152+
return __median_filter(data, kernel_size, dif)
199153
else:
200154
print("remove_outlier won't be executed because CuPy is not installed")
201155
return data

tests/test_misc/test_corr.py

Lines changed: 7 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,14 @@ def test_median_filter3d_vs_scipy_on_arange(ensure_clean_memory):
1919
mat = np.arange(4 * 5 * 6).reshape(4, 5, 6)
2020
assert_equal(
2121
scipy.ndimage.median_filter(np.float32(mat), size=3),
22-
median_filter(
23-
cp.asarray(mat, dtype=cp.float32), kernel_size=3, axis=None
24-
).get(),
22+
median_filter(cp.asarray(mat, dtype=cp.float32), kernel_size=3).get(),
2523
)
2624

2725

2826
def test_median_filter3d_vs_scipy(host_data, ensure_clean_memory):
2927
assert_equal(
3028
scipy.ndimage.median_filter(np.float32(host_data), size=3),
31-
median_filter(
32-
cp.asarray(host_data, dtype=cp.float32), kernel_size=3, axis=None
33-
).get(),
29+
median_filter(cp.asarray(host_data, dtype=cp.float32), kernel_size=3).get(),
3430
)
3531

3632

@@ -43,7 +39,6 @@ def test_median_filter3d_benchmark(
4339
median_filter,
4440
cp.asarray(host_data, dtype=cp.float32),
4541
kernel_size=kernel_size,
46-
axis=None,
4742
)
4843

4944

@@ -76,7 +71,7 @@ def test_median_filter3d_wrong_dtype(data):
7671

7772

7873
def test_median_filter3d(data):
79-
filtered_data = median_filter(data, kernel_size=3, axis=None).get()
74+
filtered_data = median_filter(data, kernel_size=3).get()
8075

8176
assert filtered_data.ndim == 3
8277
assert_allclose(np.mean(filtered_data), 808.753494, rtol=eps)
@@ -88,49 +83,25 @@ def test_median_filter3d(data):
8883
assert filtered_data.flags.c_contiguous
8984

9085
assert (
91-
median_filter(data.astype(cp.float32), kernel_size=5, axis=None, dif=1.5)
92-
.get()
93-
.dtype
86+
median_filter(data.astype(cp.float32), kernel_size=5, dif=1.5).get().dtype
9487
== np.float32
9588
)
9689

9790

98-
@pytest.mark.parametrize("axis", [0, 1, 2])
99-
def test_median_filter2d_axes(data, axis):
100-
filtered_data = median_filter(data, kernel_size=3, axis=axis).get()
101-
102-
assert filtered_data.ndim == 3
103-
assert filtered_data.dtype == np.uint16
104-
assert filtered_data.flags.c_contiguous
105-
106-
107-
def test_median_filter2d(data):
108-
filtered_data = median_filter(data, kernel_size=3, axis=0).get()
109-
110-
assert filtered_data.ndim == 3
111-
assert_allclose(np.mean(filtered_data), 808.86170708, rtol=eps)
112-
assert_allclose(np.mean(filtered_data, axis=(1, 2)).sum(), 145595.1072753)
113-
assert_allclose(np.max(filtered_data), 1080)
114-
assert_allclose(np.min(filtered_data), 80)
115-
116-
assert filtered_data.dtype == np.uint16
117-
assert filtered_data.flags.c_contiguous
118-
119-
12091
@pytest.mark.perf
12192
def test_median_filter3d_performance(ensure_clean_memory):
12293
dev = cp.cuda.Device()
12394
data_host = np.random.random_sample(size=(450, 2160, 2560)).astype(np.float32) * 2.0
12495
data = cp.asarray(data_host, dtype=cp.float32)
12596

12697
# warm up
127-
median_filter(data, kernel_size=3, axis=None)
98+
median_filter(data, kernel_size=3)
12899
dev.synchronize()
129100

130101
start = time.perf_counter_ns()
131102
nvtx.RangePush("Core")
132103
for _ in range(10):
133-
median_filter(data, kernel_size=3, axis=None)
104+
median_filter(data, kernel_size=3)
134105
nvtx.RangePop()
135106
dev.synchronize()
136107
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10
@@ -139,7 +110,7 @@ def test_median_filter3d_performance(ensure_clean_memory):
139110

140111

141112
def test_remove_outlier3d(data):
142-
filtered_data = remove_outlier(data, kernel_size=3, axis=None, dif=1.5).get()
113+
filtered_data = remove_outlier(data, kernel_size=3, dif=1.5).get()
143114

144115
assert filtered_data.ndim == 3
145116
assert_allclose(np.mean(filtered_data), 808.753494, rtol=eps)
@@ -154,16 +125,3 @@ def test_remove_outlier3d(data):
154125
== np.float32
155126
)
156127
assert filtered_data.flags.c_contiguous
157-
158-
159-
def test_remove_outlier2d(data):
160-
filtered_data = remove_outlier(data, kernel_size=3, axis=0, dif=1.5).get()
161-
162-
assert filtered_data.ndim == 3
163-
assert_allclose(np.mean(filtered_data), 809.049971, rtol=eps)
164-
assert_allclose(np.mean(filtered_data, axis=(1, 2)).sum(), 145628.994824)
165-
assert_allclose(np.max(filtered_data), 1136)
166-
assert_allclose(np.min(filtered_data), 62)
167-
168-
assert filtered_data.dtype == np.uint16
169-
assert filtered_data.flags.c_contiguous

0 commit comments

Comments
 (0)