diff --git a/torch_radon/__init__.py b/torch_radon/__init__.py index fd038e9..99984d0 100644 --- a/torch_radon/__init__.py +++ b/torch_radon/__init__.py @@ -93,14 +93,14 @@ def filter_sinogram(self, sinogram, filter_name="ramp"): padded_sinogram = F.pad(sinogram.float(), (0, pad, 0, 0)) # TODO should be possible to use onesided=True saving memory and time - sino_fft = torch.rfft(padded_sinogram, 1, normalized=True, onesided=False) + sino_fft = torch.fft.fft(padded_sinogram) # get filter and apply f = self.fourier_filters.get(padded_size, filter_name, sinogram.device) - filtered_sino_fft = sino_fft * f + filtered_sino_fft = sino_fft * f.squeeze(2).unsqueeze(1) # Inverse fft - filtered_sinogram = torch.irfft(filtered_sino_fft, 1, normalized=True, onesided=False) + filtered_sinogram = torch.real(torch.fft.ifft(filtered_sino_fft)) # pad removal and rescaling filtered_sinogram = filtered_sinogram[:, :, :-pad] * (np.pi / (2 * n_angles))