Skip to content

Commit 2979406

Browse files
committed
Add raven filter to stripe.py and add performance test
1 parent e66c17a commit 2979406

File tree

3 files changed

+77
-22
lines changed

3 files changed

+77
-22
lines changed

docs/source/examples/raven_filter_example.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from httomolibgpu.misc.raven_filter import raven_filter
99

1010
import matplotlib.pyplot as plt
11+
import time
12+
1113

1214
# Load the sinogram data
1315
path_lib = os.path.dirname(httomolibgpu.__file__)
@@ -31,8 +33,12 @@
3133
# Make a numpy copy
3234
sinogram_padded = np.pad(sinogram.get(), 20, "edge")
3335

36+
start_time = time.time()
3437
# GPU filter
3538
sinogram_gpu_filter = raven_filter(sinogram, u0, n, v0)
39+
print("--- %s seconds ---" % (time.time() - start_time))
40+
41+
start_time = time.time()
3642

3743
# Size
3844
width1 = sino_shape[1] + 2 * 20
@@ -61,6 +67,8 @@
6167
sino[row1:row2] = sino[row1:row2] * filtercomplex
6268
sino = ifft_object(fft.ifftshift(sino))
6369

70+
print("--- %s seconds ---" % (time.time() - start_time))
71+
6472
# subplots(r, c): specify the number of rows and columns of the figure grid
6573
f, axarr = plt.subplots(2,2)
6674

httomolibgpu/prep/stripe.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@
3030

3131
if cupy_run:
3232
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
33-
from httomolibgpu.misc.raven_filter import (
34-
raven_filter,
35-
)
33+
from cupyx.scipy.fft import fft2, ifft2, fftshift
34+
from httomolibgpu.cuda_kernels import load_cuda_module
3635
else:
3736
median_filter = Mock()
3837
binary_dilation = Mock()
3938
uniform_filter1d = Mock()
40-
raven_filter = Mock()
39+
fft2 = Mock()
40+
ifft2 = Mock()
41+
fftshift = Mock()
4142

4243
from typing import Union
4344

@@ -363,26 +364,48 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True):
363364
sinogram = _rs_large(sinogram, snr, size, matindex)
364365
return sinogram
365366

366-
def _raven_filter(sinogram, snr, size, matindex, vvalue=10, uvalue=10, nvalue=10 ):
367+
def raven_filter(
368+
sinogram,
369+
uvalue: int = 20,
370+
nvalue: int = 4,
371+
vvalue: int = 2,
372+
pad_y: int = 20,
373+
pad_x: int = 20,
374+
pad_method: str = "edge"):
367375
"""
368376
Raven filter
369377
"""
370-
padding = 2
371-
(nrow, ncol) = sinogram.shape
372-
width1 = nrow + 2 * padding #sino_shape[1] + 2 * self.pad
373-
height1 = ncol + 2 * padding #sino_shape[0] + 2 * self.pad
374-
375-
# Create filter
376-
centerx = np.ceil(width1 / 2.0) - 1.0
377-
centery = np.int16(np.ceil(height1 / 2.0) - 1)
378-
row1 = centery - vvalue
379-
row2 = centery + vvalue + 1
380-
listx = np.arange(width1) - centerx
381-
filtershape = 1.0 / (1.0 + np.power(listx / uvalue, 2 * nvalue))
382-
filtershapepad2d = np.zeros((self.row2 - self.row1, filtershape.size))
383-
filtershapepad2d[:] = np.float64(filtershape)
384-
filtercomplex = filtershapepad2d + filtershapepad2d * 1j
385378

379+
# Padding of the data
380+
padded_data = cp.pad(sinogram, ((0, 0), (pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
381+
# padded_data = cp.pad(sinogram, ((pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
382+
383+
# FFT and shift of data
384+
fft_data = fft2(padded_data, axes=(-2, -1), overwrite_x=True)
385+
fft_data_shifted = fftshift(fft_data)
386+
387+
# Setup various values for the filter
388+
_, height, width = sinogram.shape
389+
390+
height1 = height + 2 * pad_y
391+
width1 = width + 2 * pad_x
392+
393+
# setting grid/block parameters
394+
block_x = 128
395+
block_dims = (block_x, 1, 1)
396+
grid_x = (width1 + block_x - 1) // block_x
397+
grid_y = height1
398+
grid_dims = (grid_x, grid_y, 1)
399+
params = (fft_data_shifted, fft_data, width1, height1, uvalue, nvalue, vvalue)
400+
401+
raven_module = load_cuda_module("raven_filter")
402+
raven_filt = raven_module.get_function("raven_filter")
403+
404+
raven_filt(grid_dims, block_dims, params)
405+
406+
# raven_filt already performs the ifftshift, so no explicit ifftshift is needed here
407+
# fft_data = ifftshift(fft_data_shifted)
408+
sinogram = ifft2(fft_data, axes=(-2, -1), overwrite_x=True)
386409

387410
return sinogram
388411

tests/test_prep/test_stripe.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
from cupy.cuda import nvtx
44
import numpy as np
55
import pytest
6+
import pyfftw
7+
import pyfftw.interfaces.numpy_fft as fft
68
from httomolibgpu.prep.normalize import normalize
79
from httomolibgpu.prep.stripe import (
810
remove_stripe_based_sorting,
911
remove_stripe_ti,
1012
remove_all_stripe,
13+
raven_filter,
1114
)
1215
from numpy.testing import assert_allclose
1316

@@ -51,7 +54,6 @@ def test_remove_stripe_ti_on_data(data, flats, darks):
5154
# np.median(corrected_data), np.median(corrected_host_data), rtol=1e-6
5255
# )
5356

54-
5557
def test_stripe_removal_sorting_cupy(data, flats, darks):
5658
# --- testing the CuPy port of TomoPy's implementation ---#
5759
data = normalize(data, flats, darks, cutoff=10, minus_log=True)
@@ -66,7 +68,6 @@ def test_stripe_removal_sorting_cupy(data, flats, darks):
6668
assert corrected_data.dtype == np.float32
6769
assert corrected_data.flags.c_contiguous
6870

69-
7071
@pytest.mark.perf
7172
def test_stripe_removal_sorting_cupy_performance(ensure_clean_memory):
7273
data_host = (
@@ -116,6 +117,29 @@ def test_remove_stripe_ti_performance(ensure_clean_memory):
116117

117118
assert "performance in ms" == duration_ms
118119

120+
@pytest.mark.perf
121+
def test_raven_filter_performance(ensure_clean_memory):
122+
data_host = (
123+
np.random.random_sample(size=(1801, 5, 2560)).astype(np.float32) * 2.0 + 0.001
124+
)
125+
data = cp.asarray(data_host, dtype=np.float32)
126+
127+
# do a cold run first
128+
raven_filter(cp.copy(data))
129+
130+
dev = cp.cuda.Device()
131+
dev.synchronize()
132+
133+
start = time.perf_counter_ns()
134+
nvtx.RangePush("Core")
135+
for _ in range(10):
136+
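# take a fresh copy for each run, in case data is modified in-place (NOTE(review): raven_filter pads into a new array via cp.pad, so the input may not actually be modified — confirm)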
137+
raven_filter(cp.copy(data))
138+
nvtx.RangePop()
139+
dev.synchronize()
140+
duration_ms = float(time.perf_counter_ns() - start) * 1e-6 / 10
141+
142+
assert "performance in ms" == duration_ms
119143

120144
def test_remove_all_stripe_on_data(data, flats, darks):
121145
# --- testing the CuPy implementation from TomoCupy ---#

0 commit comments

Comments
 (0)