Skip to content

Commit baa5669

Browse files
committed
Add input type validation to raven filter
1 parent f84fc78 commit baa5669

File tree

5 files changed

+122
-67
lines changed

5 files changed

+122
-67
lines changed

docs/source/examples/raven_filter_example.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
sino_shape = sinogram.shape
2525

26-
sinogram_stack = cp.stack([sinogram] * 20, axis=1)
26+
sinogram_stack = cp.stack([sinogram] * 5, axis=1)
2727

2828
print("The shape of the sinogram stack is {}".format(cp.shape(sinogram_stack)))
2929

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
#include <cupy/complex.cuh>
22

3-
extern "C" __global__ void
3+
template <typename Type>
4+
__global__ void
45
raven_filter(
5-
complex<float> *input,
6-
complex<float> *output,
6+
complex<Type> *input,
7+
complex<Type> *output,
78
int width, int images, int height,
89
int u0, int n, int v0) {
910

@@ -17,17 +18,21 @@ raven_filter(
1718
int centerx = width / 2;
1819
int centerz = height / 2;
1920

20-
complex<float> value = input[pz * width * images + py * width + px];
21+
long long index = static_cast<long long>(px) +
22+
width * static_cast<long long>(py) +
23+
width * images * static_cast<long long>(pz);
24+
25+
complex<Type> value = input[index];
2126
if( pz >= (centerz - v0) && pz < (centerz + v0 + 1) ) {
2227

2328
// +1 needed to match with CPU implementation
24-
float base = float(px - centerx + 1) / u0;
25-
float power = base;
29+
Type base = Type(px - centerx + 1) / u0;
30+
Type power = base;
2631
for( int i = 1; i < 2 * n; i++ )
2732
power *= base;
2833

29-
float filtered_value = 1.f / (1.f + power);
30-
value *= complex<float>(filtered_value, filtered_value);
34+
Type filtered_value = 1.f / (1.f + power);
35+
value *= complex<Type>(filtered_value, filtered_value);
3136
}
3237

3338
// ifftshifting positions
@@ -36,5 +41,9 @@ raven_filter(
3641
int outX = (px + xshift) % width;
3742
int outZ = (pz + zshift) % height;
3843

39-
output[outZ * width * images + py * width + outX] = value;
44+
long long outIndex = static_cast<long long>(outX) +
45+
width * static_cast<long long>(py) +
46+
width * images * static_cast<long long>(outZ);
47+
48+
output[outIndex] = value;
4049
}

httomolibgpu/prep/stripe.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -375,9 +375,47 @@ def raven_filter(
375375
pad_x: int = 20,
376376
pad_method: str = "edge"):
377377
"""
378-
Raven filter
378+
Applies raven filter to a 3D CuPy array. For more detailed information, see :ref:`method_raven_filter`.
379+
380+
Parameters
381+
----------
382+
data : cp.ndarray
383+
Input CuPy 3D array either float32 or uint16 data type.
384+
385+
pad_y : int, optional
386+
Pad the top and bottom of projections.
387+
388+
pad_x : int, optional
389+
Pad the left and right of projections.
390+
391+
pad_method : str, optional
392+
Numpy pad method to use.
393+
394+
uvalue : int, optional
395+
The shape of filter.
396+
397+
nvalue : int, optional
398+
The shape of filter.
399+
400+
vvalue : int, optional
401+
The number of rows to be applied the filter
402+
403+
Returns
404+
-------
405+
ndarray
406+
Raven filtered 3D CuPy array in float32 data type.
407+
408+
Raises
409+
------
410+
ValueError
411+
If the input array is not three dimensional.
379412
"""
380413

414+
input_type = sinogram.dtype
415+
416+
if input_type not in ["float32", "float64"]:
417+
raise ValueError("The input data should be either float32 or float64 data type")
418+
381419
# Padding of the sinogram
382420
sinogram = cp.pad(sinogram, ((pad_y, pad_y), (0, 0), (pad_x, pad_x)), mode=pad_method)
383421

@@ -388,6 +426,11 @@ def raven_filter(
388426
# Setup various values for the filter
389427
height, images, width = sinogram.shape
390428

429+
# Set the input type of the kernel
430+
kernel_args = "raven_filter<{0}>".format(
431+
"float" if input_type == "float32" else "double"
432+
)
433+
391434
# setting grid/block parameters
392435
block_x = 128
393436
block_dims = (block_x, 1, 1)
@@ -397,8 +440,8 @@ def raven_filter(
397440
grid_dims = (grid_x, grid_y, grid_z)
398441
params = (fft_data_shifted, fft_data, width, images, height, uvalue, nvalue, vvalue)
399442

400-
raven_module = load_cuda_module("raven_filter")
401-
raven_filt = raven_module.get_function("raven_filter")
443+
raven_module = load_cuda_module("raven_filter", name_expressions=[kernel_args])
444+
raven_filt = raven_module.get_function(kernel_args)
402445

403446
raven_filt(grid_dims, block_dims, params)
404447

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import numpy as np
2+
import pyfftw
3+
import pyfftw.interfaces.numpy_fft as fft
4+
5+
def raven_filter_cpu(
6+
sinogram,
7+
uvalue: int = 20,
8+
nvalue: int = 4,
9+
vvalue: int = 2,
10+
pad_y: int = 20,
11+
pad_x: int = 20,
12+
pad_method: str = "edge"):
13+
14+
# Parameters
15+
v0 = vvalue
16+
n = nvalue
17+
u0 = uvalue
18+
19+
# Make a padded copy
20+
sinogram_padded = np.pad(sinogram, ((pad_y,pad_y), (0, 0), (pad_x,pad_x)), pad_method)
21+
22+
# Size
23+
height, images, width = sinogram_padded.shape
24+
25+
# Generate filter function
26+
centerx = np.ceil(width / 2.0) - 1.0
27+
centery = np.int16(np.ceil(height / 2.0) - 1)
28+
row1 = centery - v0
29+
row2 = centery + v0 + 1
30+
listx = np.arange(width) - centerx
31+
filtershape = 1.0 / (1.0 + np.power(listx / u0, 2 * n))
32+
filtershapepad2d = np.zeros((row2 - row1, filtershape.size))
33+
filtershapepad2d[:] = np.float64(filtershape)
34+
filtercomplex = filtershapepad2d + filtershapepad2d * 1j
35+
36+
# Generate filter objects
37+
a = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
38+
b = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
39+
c = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
40+
d = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
41+
fft_object = pyfftw.FFTW(a, b, axes=(0, 2))
42+
ifft_object = pyfftw.FFTW(c, d, axes=(0, 2), direction='FFTW_BACKWARD')
43+
44+
sino = fft.fftshift(fft_object(sinogram_padded), axes=(0, 2))
45+
for m in range(sino.shape[1]):
46+
sino[row1:row2, m] = sino[row1:row2, m] * filtercomplex
47+
sino = ifft_object(fft.ifftshift(sino, axes=(0, 2)))
48+
sinogram = sino[pad_y:height-pad_y, :, pad_x:width-pad_x]
49+
50+
return sinogram.real

tests/test_prep/test_stripe.py

Lines changed: 7 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from cupy.cuda import nvtx
44
import numpy as np
55
import pytest
6-
import pyfftw
7-
import pyfftw.interfaces.numpy_fft as fft
6+
87
from httomolibgpu.prep.normalize import normalize
98
from httomolibgpu.prep.stripe import (
109
remove_stripe_based_sorting,
@@ -13,53 +12,7 @@
1312
raven_filter,
1413
)
1514
from numpy.testing import assert_allclose
16-
17-
def raven_filter_cpu(
18-
sinogram,
19-
uvalue: int = 20,
20-
nvalue: int = 4,
21-
vvalue: int = 2,
22-
pad_y: int = 20,
23-
pad_x: int = 20,
24-
pad_method: str = "edge"):
25-
26-
# Parameters
27-
v0 = vvalue
28-
n = nvalue
29-
u0 = uvalue
30-
31-
# Make a padded copy
32-
sinogram_padded = cp.pad(sinogram, ((pad_y,pad_y), (0, 0), (pad_x,pad_x)), pad_method).get()
33-
34-
# Size
35-
height, images, width = sinogram_padded.shape
36-
37-
# Generate filter function
38-
centerx = np.ceil(width / 2.0) - 1.0
39-
centery = np.int16(np.ceil(height / 2.0) - 1)
40-
row1 = centery - v0
41-
row2 = centery + v0 + 1
42-
listx = np.arange(width) - centerx
43-
filtershape = 1.0 / (1.0 + np.power(listx / u0, 2 * n))
44-
filtershapepad2d = np.zeros((row2 - row1, filtershape.size))
45-
filtershapepad2d[:] = np.float64(filtershape)
46-
filtercomplex = filtershapepad2d + filtershapepad2d * 1j
47-
48-
# Generate filter objects
49-
a = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
50-
b = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
51-
c = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
52-
d = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
53-
fft_object = pyfftw.FFTW(a, b, axes=(0, 2))
54-
ifft_object = pyfftw.FFTW(c, d, axes=(0, 2), direction='FFTW_BACKWARD')
55-
56-
sino = fft.fftshift(fft_object(sinogram_padded), axes=(0, 2))
57-
for m in range(sino.shape[1]):
58-
sino[row1:row2, m] = sino[row1:row2, m] * filtercomplex
59-
sino = ifft_object(fft.ifftshift(sino, axes=(0, 2)))
60-
sinogram = sino[pad_y:height-pad_y, :, pad_x:width-pad_x]
61-
62-
return sinogram.real
15+
from .stripe_cpu_reference import raven_filter_cpu
6316

6417
def test_remove_stripe_ti_on_data(data, flats, darks):
6518
# --- testing the CuPy implementation from TomoCupy ---#
@@ -119,10 +72,10 @@ def test_stripe_raven_cupy(data, flats, darks):
11972

12073
data = normalize(data, flats, darks, cutoff=10, minus_log=True)
12174

122-
data_after_raven_gpu = raven_filter(cp.copy(data)).get()
123-
data_after_raven_cpu = raven_filter_cpu(cp.copy(data))
75+
data_after_raven_gpu = raven_filter(cp.copy(data))
76+
data_after_raven_cpu = cp.asarray(raven_filter_cpu(cp.copy(data).get()))
12477

125-
assert_allclose(data_after_raven_cpu, data_after_raven_gpu, 0, atol=4e-01)
78+
cp.testing.assert_allclose(data_after_raven_cpu, data_after_raven_gpu, rtol=0, atol=4e-01)
12679

12780
data = None #: free up GPU memory
12881
# make sure the output is float32
@@ -210,11 +163,11 @@ def test_raven_filter_cpu_performance(ensure_clean_memory):
210163
data = cp.asarray(data_host, dtype=np.float32)
211164

212165
# do a cold run first
213-
raven_filter_cpu(cp.copy(data))
166+
raven_filter_cpu(cp.copy(data).get())
214167

215168
start = time.perf_counter_ns()
216169
for _ in range(10):
217-
raven_filter_cpu(cp.copy(data))
170+
raven_filter_cpu(cp.copy(data).get())
218171

219172
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10
220173

0 commit comments

Comments
 (0)