Skip to content

Commit a41819d

Browse files
committed
Add raven filter
1 parent a77f334 commit a41819d

File tree

3 files changed

+204
-1
lines changed

3 files changed

+204
-1
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#define IDX2R(i,j,N) (((i)*(N))+(j))
2+
3+
template <typename Type>
4+
__global__ void fftshift_2D(Type *out, int N1, int N2)
5+
{
6+
int i = threadIdx.y + blockDim.y * blockIdx.y;
7+
int j = threadIdx.x + blockDim.x * blockIdx.x;
8+
9+
if (i < N1 && j < N2)
10+
{
11+
double a = 1-2*((i+j)&1);
12+
data[IDX2R(i,j,N2)].x *= a;
13+
data[IDX2R(i,j,N2)].y *= a;
14+
}
15+
}
16+
17+
template <typename Type, int diameter>
18+
__global__ void raven__filter_kernel3d(const Type *in, Type *out, float dif,
19+
int Z, int M, int N) {
20+
constexpr int radius = diameter / 2;
21+
constexpr int d3 = diameter * diameter * diameter;
22+
constexpr int midpoint = d3 / 2;
23+
24+
Type ValVec[d3];
25+
const long i = blockDim.x * blockIdx.x + threadIdx.x;
26+
const long j = blockDim.y * blockIdx.y + threadIdx.y;
27+
const long k = blockDim.z * blockIdx.z + threadIdx.z;
28+
29+
if (i >= N || j >= M || k >= Z)
30+
return;
31+
32+
long long index = static_cast<long long>(i) + N * static_cast<long long>(j) + N * M * static_cast<long long>(k);
33+
34+
int counter = 0;
35+
for (int i_m = -radius; i_m <= radius; i_m++) {
36+
long long i1 = i + i_m; // using long long to avoid integer overflows
37+
if ((i1 < 0) || (i1 >= N))
38+
i1 = i;
39+
for (int j_m = -radius; j_m <= radius; j_m++) {
40+
long long j1 = j + j_m;
41+
if ((j1 < 0) || (j1 >= M))
42+
j1 = j;
43+
for (int k_m = -radius; k_m <= radius; k_m++) {
44+
long long k1 = k + k_m;
45+
if ((k1 < 0) || (k1 >= Z))
46+
k1 = k;
47+
ValVec[counter] = in[i1 + N * j1 + N * M * k1];
48+
counter++;
49+
}
50+
}
51+
}
52+
53+
/* do bubble sort here */
54+
for (int x = 0; x < d3 - 1; x++) {
55+
for (int y = 0; y < d3 - x - 1; y++) {
56+
if (ValVec[y] > ValVec[y + 1]) {
57+
Type temp = ValVec[y];
58+
ValVec[y] = ValVec[y + 1];
59+
ValVec[y + 1] = temp;
60+
}
61+
}
62+
}
63+
64+
if (dif > 0.0f) {
65+
/* perform dezingering */
66+
out[index] =
67+
fabsf(in[index] - ValVec[midpoint]) >= dif ? ValVec[midpoint] : in[index];
68+
}
69+
else out[index] = ValVec[midpoint]; /* median filtering */
70+
}

httomolibgpu/misc/raven_filter.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
# ---------------------------------------------------------------------------
4+
# Copyright 2022 Diamond Light Source Ltd.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
# ---------------------------------------------------------------------------
18+
# Created By : Tomography Team at DLS <[email protected]>
19+
# Created Date: 21/October/2022
20+
# ---------------------------------------------------------------------------
21+
""" Module for data correction. For more detailed information see :ref:`data_correction_module`.
22+
23+
"""
24+
25+
import numpy as np
26+
from typing import Union
27+
28+
from httomolibgpu import cupywrapper
29+
30+
cp = cupywrapper.cp
31+
cupy_run = cupywrapper.cupy_run
32+
33+
from numpy import float32
34+
from unittest.mock import Mock
35+
36+
if cupy_run:
37+
from httomolibgpu.cuda_kernels import load_cuda_module
38+
else:
39+
load_cuda_module = Mock()
40+
41+
42+
__all__ = [
43+
"raven_filter",
44+
]
45+
46+
47+
def raven_filter(
48+
data: cp.ndarray,
49+
kernel_size: int = 3,
50+
dif: float = 0.0,
51+
) -> cp.ndarray:
52+
"""
53+
Applies raven filter to a 3D CuPy array. For more detailed information, see :ref:`method_raven_filter`.
54+
55+
Parameters
56+
----------
57+
data : cp.ndarray
58+
Input CuPy 3D array either float32 or uint16 data type.
59+
kernel_size : int, optional
60+
The size of the filter's kernel (a diameter).
61+
dif : float, optional
62+
Expected difference value between outlier value and the
63+
median value of the array, leave equal to 0 for classical median.
64+
65+
Returns
66+
-------
67+
ndarray
68+
Median filtered 3D CuPy array either float32 or uint16 data type.
69+
70+
Raises
71+
------
72+
ValueError
73+
If the input array is not three dimensional.
74+
"""
75+
input_type = data.dtype
76+
77+
if input_type not in ["float32", "uint16"]:
78+
raise ValueError("The input data should be either float32 or uint16 data type")
79+
80+
if data.ndim == 3:
81+
if 0 in data.shape:
82+
raise ValueError("The length of one of dimensions is equal to zero")
83+
else:
84+
raise ValueError("The input array must be a 3D array")
85+
86+
if kernel_size not in [3, 5, 7, 9, 11, 13]:
87+
raise ValueError("Please select a correct kernel size: 3, 5, 7, 9, 11, 13")
88+
89+
dz, dy, dx = data.shape
90+
output = cp.copy(data, order="C")
91+
92+
# 3d median or dezinger
93+
kernel_args = "median_general_kernel3d<{0}, {1}>".format(
94+
"float" if input_type == "float32" else "unsigned short", kernel_size
95+
)
96+
block_x = 128
97+
# setting grid/block parameters
98+
block_dims = (block_x, 1, 1)
99+
grid_x = (dx + block_x - 1) // block_x
100+
grid_y = dy
101+
grid_z = dz
102+
grid_dims = (grid_x, grid_y, grid_z)
103+
params = (data, output, cp.float32(dif), dz, dy, dx)
104+
105+
median_module = load_cuda_module("raven_filter", name_expressions=[kernel_args])
106+
median_filt = median_module.get_function(kernel_args)
107+
108+
median_filt(grid_dims, block_dims, params)
109+
110+
return output

httomolibgpu/prep/stripe.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,12 @@
2929
from unittest.mock import Mock
3030

3131
if cupy_run:
32-
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
32+
from cupyx.scipy.ndimage import median_filter, binary_dilation, raven_filter, uniform_filter1d
3333
else:
3434
median_filter = Mock()
3535
binary_dilation = Mock()
3636
uniform_filter1d = Mock()
37+
raven_filter = Mock()
3738

3839
from typing import Union
3940

@@ -359,6 +360,28 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
359360
sinogram = _rs_large(sinogram, snr, size, matindex)
360361
return sinogram
361362

363+
def _raven_filter(sinogram, snr, size, matindex, vvalue=10, uvalue=10, nvalue=10 ):
364+
"""
365+
Raven filter
366+
"""
367+
padding = 2
368+
(nrow, ncol) = sinogram.shape
369+
width1 = nrow + 2 * padding #sino_shape[1] + 2 * self.pad
370+
height1 = ncol + 2 * padding #sino_shape[0] + 2 * self.pad
371+
372+
# Create filter
373+
centerx = np.ceil(width1 / 2.0) - 1.0
374+
centery = np.int16(np.ceil(height1 / 2.0) - 1)
375+
row1 = centery - vvalue
376+
row2 = centery + vvalue + 1
377+
listx = np.arange(width1) - centerx
378+
filtershape = 1.0 / (1.0 + np.power(listx / uvalue, 2 * nvalue))
379+
filtershapepad2d = np.zeros((self.row2 - self.row1, filtershape.size))
380+
filtershapepad2d[:] = np.float64(filtershape)
381+
filtercomplex = filtershapepad2d + filtershapepad2d * 1j
382+
383+
384+
return sinogram
362385

363386
def _create_matindex(nrow, ncol):
364387
"""

0 commit comments

Comments
 (0)