Skip to content

Commit f478663

Browse files
committed
Update raven filter
1 parent a41819d commit f478663

File tree

2 files changed

+48
-74
lines changed

2 files changed

+48
-74
lines changed
Lines changed: 2 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -1,70 +1,13 @@
1-
#define IDX2R(i,j,N) (((i)*(N))+(j))
2-
3-
template <typename Type>
4-
__global__ void fftshift_2D(Type *out, int N1, int N2)
5-
{
6-
int i = threadIdx.y + blockDim.y * blockIdx.y;
7-
int j = threadIdx.x + blockDim.x * blockIdx.x;
8-
9-
if (i < N1 && j < N2)
10-
{
11-
double a = 1-2*((i+j)&1);
12-
data[IDX2R(i,j,N2)].x *= a;
13-
data[IDX2R(i,j,N2)].y *= a;
14-
}
15-
}
161

172
template <typename Type, int diameter>
18-
__global__ void raven__filter_kernel3d(const Type *in, Type *out, float dif,
19-
int Z, int M, int N) {
20-
constexpr int radius = diameter / 2;
21-
constexpr int d3 = diameter * diameter * diameter;
22-
constexpr int midpoint = d3 / 2;
3+
__global__ void raven__filter_kernel3d(Type *in, int Z, int M, int N) {
234

24-
Type ValVec[d3];
255
const long i = blockDim.x * blockIdx.x + threadIdx.x;
266
const long j = blockDim.y * blockIdx.y + threadIdx.y;
277
const long k = blockDim.z * blockIdx.z + threadIdx.z;
288

299
if (i >= N || j >= M || k >= Z)
3010
return;
3111

32-
long long index = static_cast<long long>(i) + N * static_cast<long long>(j) + N * M * static_cast<long long>(k);
33-
34-
int counter = 0;
35-
for (int i_m = -radius; i_m <= radius; i_m++) {
36-
long long i1 = i + i_m; // using long long to avoid integer overflows
37-
if ((i1 < 0) || (i1 >= N))
38-
i1 = i;
39-
for (int j_m = -radius; j_m <= radius; j_m++) {
40-
long long j1 = j + j_m;
41-
if ((j1 < 0) || (j1 >= M))
42-
j1 = j;
43-
for (int k_m = -radius; k_m <= radius; k_m++) {
44-
long long k1 = k + k_m;
45-
if ((k1 < 0) || (k1 >= Z))
46-
k1 = k;
47-
ValVec[counter] = in[i1 + N * j1 + N * M * k1];
48-
counter++;
49-
}
50-
}
51-
}
52-
53-
/* do bubble sort here */
54-
for (int x = 0; x < d3 - 1; x++) {
55-
for (int y = 0; y < d3 - x - 1; y++) {
56-
if (ValVec[y] > ValVec[y + 1]) {
57-
Type temp = ValVec[y];
58-
ValVec[y] = ValVec[y + 1];
59-
ValVec[y + 1] = temp;
60-
}
61-
}
62-
}
63-
64-
if (dif > 0.0f) {
65-
/* perform dezingering */
66-
out[index] =
67-
fabsf(in[index] - ValVec[midpoint]) >= dif ? ValVec[midpoint] : in[index];
68-
}
69-
else out[index] = ValVec[midpoint]; /* median filtering */
12+
7013
}

httomolibgpu/misc/raven_filter.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,26 @@
3535

3636
if cupy_run:
3737
from httomolibgpu.cuda_kernels import load_cuda_module
38+
from cupyx.scipy.fft import fft2, ifft2, fftshift, ifftshift
3839
else:
3940
load_cuda_module = Mock()
41+
fft2 = Mock()
42+
ifft2 = Mock()
43+
fftshift = Mock()
44+
ifftshift = Mock()
4045

4146

4247
__all__ = [
4348
"raven_filter",
4449
]
4550

4651

47-
def raven_filter(
52+
def raven_filter_savu(
4853
data: cp.ndarray,
4954
kernel_size: int = 3,
50-
dif: float = 0.0,
55+
pad_y: int = 100,
56+
pad_x: int = 100,
57+
pad_method: str = "edge",
5158
) -> cp.ndarray:
5259
"""
5360
Applies raven filter to a 3D CuPy array. For more detailed information, see :ref:`method_raven_filter`.
@@ -58,14 +65,20 @@ def raven_filter(
5865
Input CuPy 3D array either float32 or uint16 data type.
5966
kernel_size : int, optional
6067
The size of the filter's kernel (a diameter).
61-
dif : float, optional
62-
Expected difference value between outlier value and the
63-
median value of the array, leave equal to 0 for classical median.
68+
69+
pad_y : int, optional
70+
Pad the top and bottom of projections.
71+
72+
pad_x : int, optional
73+
Pad the left and right of projections.
74+
75+
pad_method : str, optional
76+
Numpy pad method to use.
6477
6578
Returns
6679
-------
6780
ndarray
68-
Median filtered 3D CuPy array either float32 or uint16 data type.
81+
Raven filtered 3D CuPy array either float32 or uint16 data type.
6982
7083
Raises
7184
------
@@ -86,11 +99,25 @@ def raven_filter(
8699
if kernel_size not in [3, 5, 7, 9, 11, 13]:
87100
raise ValueError("Please select a correct kernel size: 3, 5, 7, 9, 11, 13")
88101

89-
dz, dy, dx = data.shape
90-
output = cp.copy(data, order="C")
91102

92-
# 3d median or dezinger
93-
kernel_args = "median_general_kernel3d<{0}, {1}>".format(
103+
dz_orig, dy_orig, dx_orig = data.shape
104+
105+
padded_data, pad_tup = cp.pad(data, fftpad, "edge")
106+
dz, dy, dx = padded_data.shape
107+
108+
# 3D FFT of data
109+
padded_data = cp.pad(data, ((0, 0), (pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
110+
fft_data = fft2(padded_data, axes=(-2, -1), overwrite_x=True)
111+
fft_data_shifted = fftshift(fft_data)
112+
113+
# Setup various values for the filter
114+
_, height, width = data.shape
115+
116+
height1 = height + 2 * pad_y
117+
width1 = width + 2 * pad_x
118+
119+
# raven
120+
kernel_args = "raven_general_kernel3d<{0}, {1}>".format(
94121
"float" if input_type == "float32" else "unsigned short", kernel_size
95122
)
96123
block_x = 128
@@ -100,11 +127,15 @@ def raven_filter(
100127
grid_y = dy
101128
grid_z = dz
102129
grid_dims = (grid_x, grid_y, grid_z)
103-
params = (data, output, cp.float32(dif), dz, dy, dx)
130+
params = (fft_data_shifted, dz, dy, dx)
131+
132+
raven_module = load_cuda_module("raven_kernel", name_expressions=[kernel_args])
133+
raven_filt = raven_module.get_function(kernel_args)
134+
135+
raven_filt(grid_dims, block_dims, params)
104136

105-
median_module = load_cuda_module("raven_filter", name_expressions=[kernel_args])
106-
median_filt = median_module.get_function(kernel_args)
137+
fft_data = fftshift(fft_data_shifted)
107138

108-
median_filt(grid_dims, block_dims, params)
139+
data = ifft2(fft_data, axes=(-2, -1), overwrite_x=True, norm="forward")
109140

110-
return output
141+
return data

0 commit comments

Comments
 (0)