|
30 | 30 |
|
31 | 31 | if cupy_run: |
32 | 32 | from cupyx.scipy.ndimage import median_filter, binary_dilation, uniform_filter1d |
33 | | - from httomolibgpu.misc.raven_filter import ( |
34 | | - raven_filter, |
35 | | - ) |
| 33 | + from cupyx.scipy.fft import fft2, ifft2, fftshift |
| 34 | + from httomolibgpu.cuda_kernels import load_cuda_module |
36 | 35 | else: |
37 | 36 | median_filter = Mock() |
38 | 37 | binary_dilation = Mock() |
39 | 38 | uniform_filter1d = Mock() |
40 | | - raven_filter = Mock() |
| 39 | + fft2 = Mock() |
| 40 | + ifft2 = Mock() |
| 41 | + fftshift = Mock() |
41 | 42 |
|
42 | 43 | from typing import Union |
43 | 44 |
|
@@ -363,26 +364,48 @@ def _rs_dead(sinogram, snr, size, matindex, norm=True): |
363 | 364 | sinogram = _rs_large(sinogram, snr, size, matindex) |
364 | 365 | return sinogram |
365 | 366 |
|
366 | | -def _raven_filter(sinogram, snr, size, matindex, vvalue=10, uvalue=10, nvalue=10 ): |
| 367 | +def raven_filter( |
| 368 | + sinogram, |
| 369 | + uvalue: int = 20, |
| 370 | + nvalue: int = 4, |
| 371 | + vvalue: int = 2, |
| 372 | + pad_y: int = 20, |
| 373 | + pad_x: int = 20, |
| 374 | + pad_method: str = "edge"): |
367 | 375 | """ |
368 | 376 | Raven filter |
369 | 377 | """ |
370 | | - padding = 2 |
371 | | - (nrow, ncol) = sinogram.shape |
372 | | - width1 = nrow + 2 * padding #sino_shape[1] + 2 * self.pad |
373 | | - height1 = ncol + 2 * padding #sino_shape[0] + 2 * self.pad |
374 | | - |
375 | | - # Create filter |
376 | | - centerx = np.ceil(width1 / 2.0) - 1.0 |
377 | | - centery = np.int16(np.ceil(height1 / 2.0) - 1) |
378 | | - row1 = centery - vvalue |
379 | | - row2 = centery + vvalue + 1 |
380 | | - listx = np.arange(width1) - centerx |
381 | | - filtershape = 1.0 / (1.0 + np.power(listx / uvalue, 2 * nvalue)) |
382 | | - filtershapepad2d = np.zeros((self.row2 - self.row1, filtershape.size)) |
383 | | - filtershapepad2d[:] = np.float64(filtershape) |
384 | | - filtercomplex = filtershapepad2d + filtershapepad2d * 1j |
385 | 378 |
|
| 379 | + # Padding of the data |
| 380 | + padded_data = cp.pad(sinogram, ((0, 0), (pad_y, pad_y), (pad_x, pad_x)), mode=pad_method) |
| 381 | + # padded_data = cp.pad(sinogram, ((pad_y, pad_y), (pad_x, pad_x)), mode=pad_method) |
| 382 | + |
| 383 | + # FFT and shift of data |
| 384 | + fft_data = fft2(padded_data, axes=(-2, -1), overwrite_x=True) |
| 385 | + fft_data_shifted = fftshift(fft_data) |
| 386 | + |
| 387 | + # Setup various values for the filter |
| 388 | + _, height, width = sinogram.shape |
| 389 | + |
| 390 | + height1 = height + 2 * pad_y |
| 391 | + width1 = width + 2 * pad_x |
| 392 | + |
| 393 | + # setting grid/block parameters |
| 394 | + block_x = 128 |
| 395 | + block_dims = (block_x, 1, 1) |
| 396 | + grid_x = (width1 + block_x - 1) // block_x |
| 397 | + grid_y = height1 |
| 398 | + grid_dims = (grid_x, grid_y, 1) |
| 399 | + params = (fft_data_shifted, fft_data, width1, height1, uvalue, nvalue, vvalue) |
| 400 | + |
| 401 | + raven_module = load_cuda_module("raven_filter") |
| 402 | + raven_filt = raven_module.get_function("raven_filter") |
| 403 | + |
| 404 | + raven_filt(grid_dims, block_dims, params) |
| 405 | + |
| 406 | + # raven_filt already doing ifftshifting |
| 407 | + # fft_data = ifftshift(fft_data_shifted) |
| 408 | + sinogram = ifft2(fft_data, axes=(-2, -1), overwrite_x=True) |
386 | 409 |
|
387 | 410 | return sinogram |
388 | 411 |
|
|
0 commit comments