Skip to content

Commit f84fc78

Browse files
committed
Fix sinogram stack handle
1 parent 2979406 commit f84fc78

File tree

6 files changed

+149
-191
lines changed

6 files changed

+149
-191
lines changed

docs/source/examples/raven_filter_example.py

Lines changed: 29 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pyfftw
66
import pyfftw.interfaces.numpy_fft as fft
77
import httomolibgpu
8-
from httomolibgpu.misc.raven_filter import raven_filter
8+
from httomolibgpu.prep.stripe import raven_filter
99

1010
import matplotlib.pyplot as plt
1111
import time
@@ -23,63 +23,70 @@
2323

2424
sino_shape = sinogram.shape
2525

26-
print("The shape of the sinogram is {}".format(cp.shape(sinogram)))
26+
sinogram_stack = cp.stack([sinogram] * 20, axis=1)
27+
28+
print("The shape of the sinogram stack is {}".format(cp.shape(sinogram_stack)))
2729

2830
# Parameters
2931
v0 = 2
3032
n = 4
3133
u0 = 20
3234

3335
# Make a numpy copy
34-
sinogram_padded = np.pad(sinogram.get(), 20, "edge")
36+
sinogram_padded = np.pad(sinogram_stack.get(), [(20, 20), (0, 0), (20, 20)], "edge")
3537

3638
start_time = time.time()
3739
# GPU filter
38-
sinogram_gpu_filter = raven_filter(sinogram, u0, n, v0)
40+
sinogram_gpu_filter = raven_filter(sinogram_stack, u0, n, v0)
3941
print("--- %s seconds ---" % (time.time() - start_time))
4042

4143
start_time = time.time()
4244

4345
# Size
44-
width1 = sino_shape[1] + 2 * 20
45-
height1 = sino_shape[0] + 2 * 20
46+
height, images, width = sinogram_padded.shape
4647

4748
# Generate filter function
48-
centerx = np.ceil(width1 / 2.0) - 1.0
49-
centery = np.int16(np.ceil(height1 / 2.0) - 1)
49+
centerx = np.ceil(width / 2.0) - 1.0
50+
centery = np.int16(np.ceil(height / 2.0) - 1)
5051
row1 = centery - v0
5152
row2 = centery + v0 + 1
52-
listx = np.arange(width1) - centerx
53+
listx = np.arange(width) - centerx
5354
filtershape = 1.0 / (1.0 + np.power(listx / u0, 2 * n))
5455
filtershapepad2d = np.zeros((row2 - row1, filtershape.size))
5556
filtershapepad2d[:] = np.float64(filtershape)
5657
filtercomplex = filtershapepad2d + filtershapepad2d * 1j
5758

5859
# Generate filter objects
59-
a = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
60-
b = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
61-
c = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
62-
d = pyfftw.empty_aligned((height1, width1), dtype='complex128', n=16)
63-
fft_object = pyfftw.FFTW(a, b, axes=(0, 1))
64-
ifft_object = pyfftw.FFTW(c, d, axes=(0, 1), direction='FFTW_BACKWARD')
60+
a = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
61+
b = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
62+
c = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
63+
d = pyfftw.empty_aligned((height, images, width), dtype='complex128', n=16)
64+
fft_object = pyfftw.FFTW(a, b, axes=(0, 2))
65+
ifft_object = pyfftw.FFTW(c, d, axes=(0, 2), direction='FFTW_BACKWARD')
66+
67+
sino = fft.fftshift(fft_object(sinogram_padded), axes=(0, 2))
68+
for m in range(sino.shape[1]):
69+
sino[row1:row2, m] = sino[row1:row2, m] * filtercomplex
70+
sino = ifft_object(fft.ifftshift(sino, axes=(0, 2)))
6571

66-
sino = fft.fftshift(fft_object(sinogram_padded))
67-
sino[row1:row2] = sino[row1:row2] * filtercomplex
68-
sino = ifft_object(fft.ifftshift(sino))
72+
# Remove padding
73+
sino = sino[20:height-20, :, 20:width-20]
6974

7075
print("--- %s seconds ---" % (time.time() - start_time))
7176

7277
#subplot(r,c) provide the no. of rows and columns
7378
f, axarr = plt.subplots(2,2)
7479

80+
sino_index = 10
81+
7582
# use the created array to output your multiple images. In this case I have stacked 4 images vertically
76-
axarr[0, 0].imshow(sinogram_padded)
83+
axarr[0, 0].imshow(sinogram_stack.get()[:, sino_index, :])
7784
axarr[0, 0].set_title('Original sinogram')
78-
axarr[0, 1].imshow(sinogram_padded - sinogram_gpu_filter.get().real)
85+
axarr[0, 1].imshow(sinogram_stack.get()[:, sino_index, :] - sinogram_gpu_filter.get().real[:, sino_index, :])
7986
axarr[0, 1].set_title('Difference of original and GPU filtered')
80-
axarr[1, 0].imshow(sinogram_padded - sino.real)
87+
axarr[1, 0].imshow(sinogram_stack.get()[:, sino_index, :] - sino.real[:, sino_index, :])
8188
axarr[1, 0].set_title('Difference of original and CPU filtered')
82-
axarr[1, 1].imshow(sinogram_gpu_filter.get().real - sino.real)
89+
axarr[1, 1].imshow(sinogram_gpu_filter.get().real[:, sino_index, :] - sino.real[:, sino_index, :])
8390
axarr[1, 1].set_title('Difference of GPU and CPU filtered')
8491

8592
plt.show()
Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
#include <cupy/complex.cuh>
22

33
extern "C" __global__ void
4-
raven_filter(complex<float> *input, complex<float> *output, int width1, int height1, int u0, int n, int v0) {
4+
raven_filter(
5+
complex<float> *input,
6+
complex<float> *output,
7+
int width, int images, int height,
8+
int u0, int n, int v0) {
59

6-
int centerx = width1 / 2;
7-
int centery = height1 / 2;
10+
const int px = threadIdx.x + blockIdx.x * blockDim.x;
11+
const int py = threadIdx.y + blockIdx.y * blockDim.y;
12+
const int pz = threadIdx.z + blockIdx.z * blockDim.z;
813

9-
int px = threadIdx.x + blockIdx.x * blockDim.x;
10-
int py = threadIdx.y + blockIdx.y * blockDim.y;
11-
12-
if (px >= width1)
13-
return;
14-
if (py >= height1)
14+
if (px >= width || py >= images || pz >= height)
1515
return;
1616

17-
complex<float> value = input[py * width1 + px];
18-
if( py >= (centery - v0) && py < (centery + v0 + 1) ) {
17+
int centerx = width / 2;
18+
int centerz = height / 2;
19+
20+
complex<float> value = input[pz * width * images + py * width + px];
21+
if( pz >= (centerz - v0) && pz < (centerz + v0 + 1) ) {
1922

2023
// +1 needed to match with CPU implementation
2124
float base = float(px - centerx + 1) / u0;
@@ -28,10 +31,10 @@ raven_filter(complex<float> *input, complex<float> *output, int width1, int heig
2831
}
2932

3033
// ifftshifting positions
31-
int xshift = (width1 + 1) / 2;
32-
int yshift = (height1 + 1) / 2;
33-
int outX = (px + xshift) % width1;
34-
int outY = (py + yshift) % height1;
34+
int xshift = (width + 1) / 2;
35+
int zshift = (height + 1) / 2;
36+
int outX = (px + xshift) % width;
37+
int outZ = (pz + zshift) % height;
3538

36-
output[outY * width1 + outX] = value;
39+
output[outZ * width * images + py * width + outX] = value;
3740
}

httomolibgpu/misc/raven_filter.py

Lines changed: 0 additions & 135 deletions
This file was deleted.

httomolibgpu/prep/stripe.py

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131
if cupy_run:
3232
from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d
33-
from cupyx.scipy.fft import fft2, ifft2, fftshift
33+
from cupyx.scipy.fft import fft2, ifft2, fftshift, ifftshift
3434
from httomolibgpu.cuda_kernels import load_cuda_module
3535
else:
3636
median_filter = Mock()
@@ -39,13 +39,15 @@
3939
fft2 = Mock()
4040
ifft2 = Mock()
4141
fftshift = Mock()
42+
ifftshift = Mock()
4243

4344
from typing import Union
4445

4546
__all__ = [
4647
"remove_stripe_based_sorting",
4748
"remove_stripe_ti",
4849
"remove_all_stripe",
50+
"raven_filter",
4951
]
5052

5153

@@ -376,36 +378,36 @@ def raven_filter(
376378
Raven filter
377379
"""
378380

379-
# Padding of the data
380-
padded_data = cp.pad(sinogram, ((0, 0), (pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
381-
# padded_data = cp.pad(sinogram, ((pad_y, pad_y), (pad_x, pad_x)), mode=pad_method)
381+
# Padding of the sinogram
382+
sinogram = cp.pad(sinogram, ((pad_y, pad_y), (0, 0), (pad_x, pad_x)), mode=pad_method)
382383

383-
# FFT and shift of data
384-
fft_data = fft2(padded_data, axes=(-2, -1), overwrite_x=True)
385-
fft_data_shifted = fftshift(fft_data)
384+
# FFT and shift of sinogram
385+
fft_data = fft2(sinogram, axes=(0, 2), overwrite_x=True)
386+
fft_data_shifted = fftshift(fft_data, axes=(0, 2))
386387

387388
# Setup various values for the filter
388-
_, height, width = sinogram.shape
389-
390-
height1 = height + 2 * pad_y
391-
width1 = width + 2 * pad_x
389+
height, images, width = sinogram.shape
392390

393391
# setting grid/block parameters
394392
block_x = 128
395393
block_dims = (block_x, 1, 1)
396-
grid_x = (width1 + block_x - 1) // block_x
397-
grid_y = height1
398-
grid_dims = (grid_x, grid_y, 1)
399-
params = (fft_data_shifted, fft_data, width1, height1, uvalue, nvalue, vvalue)
394+
grid_x = (width + block_x - 1) // block_x
395+
grid_y = images
396+
grid_z = height
397+
grid_dims = (grid_x, grid_y, grid_z)
398+
params = (fft_data_shifted, fft_data, width, images, height, uvalue, nvalue, vvalue)
400399

401400
raven_module = load_cuda_module("raven_filter")
402401
raven_filt = raven_module.get_function("raven_filter")
403402

404403
raven_filt(grid_dims, block_dims, params)
405404

406405
# raven_filt already doing ifftshifting
407-
# fft_data = ifftshift(fft_data_shifted)
408-
sinogram = ifft2(fft_data, axes=(-2, -1), overwrite_x=True)
406+
# fft_data = ifftshift(fft_data_shifted, axes=(0, 2))
407+
sinogram = ifft2(fft_data, axes=(0, 2), overwrite_x=True)
408+
409+
# Removing padding
410+
sinogram = sinogram[pad_y:height-pad_y, :, pad_x:width-pad_x].real
409411

410412
return sinogram
411413

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ dev = [
5757
"toml",
5858
"imageio",
5959
"h5py",
60-
"pre-commit"
60+
"pre-commit",
61+
"pyfftw"
6162
]
6263

6364

0 commit comments

Comments
 (0)