Skip to content

Commit 444e048

Browse files
committed
Update raven filter and add example
1 parent f478663 commit 444e048

File tree

4 files changed

+147
-51
lines changed

4 files changed

+147
-51
lines changed
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import os
2+
import numpy as np
3+
import cupy as cp
4+
import scipy
5+
import pyfftw
6+
import pyfftw.interfaces.numpy_fft as fft
7+
import httomolibgpu
8+
from httomolibgpu.misc.raven_filter import raven_filter
9+
10+
import matplotlib.pyplot as plt
11+
12+
# Load the sinogram data
13+
path_lib = os.path.dirname(httomolibgpu.__file__)
14+
in_file = os.path.abspath(
15+
os.path.join(path_lib, "..", "tests/test_data/", "3600proj_sino.npz")
16+
)
17+
l_infile = np.load(in_file)
18+
sinogram = l_infile["sinogram"]
19+
angles = l_infile["angles"]
20+
sinogram = cp.asarray(sinogram)
21+
22+
sino_shape = sinogram.shape
23+
24+
print("The shape of the sinogram is {}".format(cp.shape(sinogram)))
25+
26+
# Make a numpy copy
27+
sinogram_padded = np.pad(sinogram.get(), 20, "edge")
28+
29+
# GPU filter
30+
sinogram_gpu_filter = raven_filter(sinogram)
31+
32+
# Size
33+
width1 = sino_shape[1] + 2 * 20
34+
height1 = sino_shape[0] + 2 * 20
35+
36+
# Parameters
37+
v0 = 2
38+
u0 = 20
39+
n = 2
40+
41+
# Generate filter function
42+
centerx = np.ceil(width1 / 2.0) - 1.0
43+
centery = np.int16(np.ceil(height1 / 2.0) - 1)
44+
row1 = centery - v0
45+
row2 = centery + v0 + 1
46+
listx = np.arange(width1) - centerx
47+
filtershape = 1.0 / (1.0 + np.power(listx / u0, 2 * n))
48+
filtershapepad2d = np.zeros((row2 - row1, filtershape.size))
49+
filtershapepad2d[:] = np.float64(filtershape)
50+
filtercomplex = filtershapepad2d + filtershapepad2d * 1j
51+
52+
# Generate filter objects
53+
a = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
54+
b = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
55+
c = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
56+
d = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
57+
fft_object = pyfftw.FFTW(a, b, axes=(0, 1))
58+
ifft_object = pyfftw.FFTW(c, d, axes=(0, 1), direction='FFTW_BACKWARD')
59+
60+
sino = fft.fftshift(fft_object(sinogram_padded))
61+
sino[row1:row2] = sino[row1:row2] * filtercomplex
62+
sino = ifft_object(fft.ifftshift(sino))
63+
64+
plt.figure()
65+
66+
#subplot(r,c) provide the no. of rows and columns
67+
f, axarr = plt.subplots(2,2)
68+
69+
# use the created array to output your multiple images. In this case I have stacked 4 images vertically
70+
axarr[0, 0].imshow(sinogram_padded)
71+
axarr[0, 1].imshow(sinogram_padded - sinogram_gpu_filter.get().real)
72+
axarr[1, 0].imshow(sinogram_padded - sino.real)
73+
axarr[1, 1].imshow(sinogram_gpu_filter.get().real - sino.real)
74+
75+
plt.show()
76+
Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,36 @@
1+
#include <cupy/complex.cuh>
12

2-
template <typename Type, int diameter>
3-
__global__ void raven__filter_kernel3d(Type *in, int Z, int M, int N) {
3+
extern "C" __global__ void
4+
raven_filter(complex<float> *input, complex<float> *output, int width1, int height1, int u0, int n, int v0) {
45

5-
const long i = blockDim.x * blockIdx.x + threadIdx.x;
6-
const long j = blockDim.y * blockIdx.y + threadIdx.y;
7-
const long k = blockDim.z * blockIdx.z + threadIdx.z;
6+
int centerx = (width1 + 1) / 2 - 1;
7+
int centery = (height1 + 1) / 2 - 1;
88

9-
if (i >= N || j >= M || k >= Z)
9+
int px = threadIdx.x + blockIdx.x * blockDim.x;
10+
int py = threadIdx.y + blockIdx.y * blockDim.y;
11+
12+
if (px >= width1)
13+
return;
14+
if (py >= height1)
1015
return;
1116

12-
17+
complex<float> value = input[py * width1 + px];
18+
if( py >= (centery - v0) && py <= (centery + v0 + 1) ) {
19+
20+
double base = (px - centerx) / u0;
21+
double filtered_value = base;
22+
for( int i = 1; i < 2 * n; i++ )
23+
filtered_value *= base;
24+
25+
filtered_value = 1.0f / (1.0 + filtered_value);
26+
value *= complex<float>(filtered_value, filtered_value);
27+
}
28+
29+
// ifftshifting positions
30+
int xshift = (width1 + 1) / 2;
31+
int yshift = (height1 + 1) / 2;
32+
int outX = (px + xshift) % width1;
33+
int outY = (py + yshift) % height1;
34+
35+
output[outY * width1 + outX] = value;
1336
}

httomolibgpu/misc/raven_filter.py

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,14 @@
4949
]
5050

5151

52-
def raven_filter_savu(
52+
def raven_filter(
5353
data: cp.ndarray,
54-
kernel_size: int = 3,
55-
pad_y: int = 100,
56-
pad_x: int = 100,
54+
pad_y: int = 20,
55+
pad_x: int = 20,
5756
pad_method: str = "edge",
57+
uvalue: int = 20,
58+
nvalue: int = 4,
59+
vvalue: int = 2,
5860
) -> cp.ndarray:
5961
"""
6062
Applies raven filter to a 3D CuPy array. For more detailed information, see :ref:`method_raven_filter`.
@@ -63,8 +65,6 @@ def raven_filter_savu(
6365
----------
6466
data : cp.ndarray
6567
Input CuPy 3D array either float32 or uint16 data type.
66-
kernel_size : int, optional
67-
The size of the filter's kernel (a diameter).
6868
6969
pad_y : int, optional
7070
Pad the top and bottom of projections.
@@ -75,10 +75,19 @@ def raven_filter_savu(
7575
pad_method : str, optional
7676
Numpy pad method to use.
7777
78+
uvalue : int, optional
79+
The shape of filter.
80+
81+
nvalue : int, optional
82+
The shape of filter.
83+
84+
vvalue : int, optional
85+
The number of rows to be applied the filter
86+
7887
Returns
7988
-------
8089
ndarray
81-
Raven filtered 3D CuPy array either float32 or uint16 data type.
90+
Raven filtered 3D CuPy array in float32 data type.
8291
8392
Raises
8493
------
@@ -87,55 +96,40 @@ def raven_filter_savu(
8796
"""
8897
input_type = data.dtype
8998

90-
if input_type not in ["float32", "uint16"]:
91-
raise ValueError("The input data should be either float32 or uint16 data type")
92-
93-
if data.ndim == 3:
94-
if 0 in data.shape:
95-
raise ValueError("The length of one of dimensions is equal to zero")
96-
else:
97-
raise ValueError("The input array must be a 3D array")
98-
99-
if kernel_size not in [3, 5, 7, 9, 11, 13]:
100-
raise ValueError("Please select a correct kernel size: 3, 5, 7, 9, 11, 13")
101-
99+
if input_type not in ["float32"]:
100+
raise ValueError("The input data should be float32")
102101

103-
dz_orig, dy_orig, dx_orig = data.shape
102+
# if data.ndim != 3:
103+
# raise ValueError("only 3D data is supported")
104104

105-
padded_data, pad_tup = cp.pad(data, fftpad, "edge")
106-
dz, dy, dx = padded_data.shape
105+
# Padding of the data
106+
padded_data = cp.pad(data, ((pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
107107

108-
# 3D FFT of data
109-
padded_data = cp.pad(data, ((0, 0), (pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
108+
# FFT and shift of data
110109
fft_data = fft2(padded_data, axes=(-2, -1), overwrite_x=True)
111110
fft_data_shifted = fftshift(fft_data)
112111

113112
# Setup various values for the filter
114-
_, height, width = data.shape
113+
height, width = data.shape
115114

116115
height1 = height + 2 * pad_y
117116
width1 = width + 2 * pad_x
118117

119-
# raven
120-
kernel_args = "raven_general_kernel3d<{0}, {1}>".format(
121-
"float" if input_type == "float32" else "unsigned short", kernel_size
122-
)
123-
block_x = 128
124118
# setting grid/block parameters
119+
block_x = 128
125120
block_dims = (block_x, 1, 1)
126-
grid_x = (dx + block_x - 1) // block_x
127-
grid_y = dy
128-
grid_z = dz
129-
grid_dims = (grid_x, grid_y, grid_z)
130-
params = (fft_data_shifted, dz, dy, dx)
131-
132-
raven_module = load_cuda_module("raven_kernel", name_expressions=[kernel_args])
133-
raven_filt = raven_module.get_function(kernel_args)
134-
121+
grid_x = (width1 + block_x - 1) // block_x
122+
grid_y = height1
123+
grid_dims = (grid_x, grid_y, 1)
124+
params = (fft_data_shifted, fft_data, width1, height1, uvalue, nvalue, vvalue)
125+
126+
raven_module = load_cuda_module("raven_filter")
127+
raven_filt = raven_module.get_function("raven_filter")
128+
135129
raven_filt(grid_dims, block_dims, params)
136-
137-
fft_data = fftshift(fft_data_shifted)
138-
139-
data = ifft2(fft_data, axes=(-2, -1), overwrite_x=True, norm="forward")
130+
131+
# raven_fil already doing ifftshifting
132+
# fft_data = ifftshift(fft_data_shifted)
133+
data = ifft2(fft_data, axes=(-2, -1), overwrite_x=True)
140134

141135
return data

httomolibgpu/prep/stripe.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@
2929
from unittest.mock import Mock
3030

3131
if cupy_run:
32-
from cupyx.scipy.ndimage import median_filter, binary_dilation, raven_filter, uniform_filter1d
32+
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
33+
from httomolibgpu.misc.raven_filter import (
34+
raven_filter,
35+
)
3336
else:
3437
median_filter = Mock()
3538
binary_dilation = Mock()

0 commit comments

Comments
 (0)