66
77PYNFFT_AVAILABLE = True
88try :
9- import pynfft
9+ import pyNFFT3 as pynfft3
1010except ImportError :
11- PYNUFFT_AVAILABLE = False
11+ PYNFFT_AVAILABLE = False
1212
1313
1414def get_fourier_matrix (ktraj , shape , ndim , do_ifft = False ):
@@ -24,28 +24,27 @@ def get_fourier_matrix(ktraj, shape, ndim, do_ifft=False):
2424 return matrix / np .sqrt (n )
2525
2626
27- class RawPyNFFT :
28- """Implementation of the NUDFT using numpy ."""
27+ class RawPyNFFT3 :
28+ """Binding for the pyNFFT3 package ."""
2929
3030 def __init__ (self , samples , shape ):
3131 self .samples = samples
3232 self .shape = shape
33- self .ndim = len (shape )
34- self .plan = pynfft .NFFT (N = shape , M = len (samples ))
33+ self .plan = pynfft3 .NFFT (N = np .array (shape , dtype = "int32" ), M = len (samples ))
3534 self .plan .x = self .samples
36- self .plan .precompute ()
37- self .shape = shape
3835
3936 def op (self , coeffs , image ):
4037 """Compute the forward NUDFT."""
41- self .plan .f_hat = image
42- np .copyto (coeffs , self .plan .trafo ())
38+ self .plan .fhat = image .ravel ()
39+ self .plan .trafo ()
40+ np .copyto (coeffs , self .plan .f .reshape (- 1 ))
4341 return coeffs
4442
4543 def adj_op (self , coeffs , image ):
4644 """Compute the adjoint NUDFT."""
47- self .plan .f = coeffs
48- np .copyto (image , self .plan .adjoint ())
45+ self .plan .f = coeffs .ravel ()
46+ self .plan .adjoint ()
47+ np .copyto (image , self .plan .fhat .reshape (self .shape ))
4948 return image
5049
5150
@@ -71,4 +70,4 @@ def __init__(
7170 density = density ,
7271 raw_op = None , # is set later, after normalizing samples.
7372 )
74- self .raw_op = RawPyNFFT (self .samples , shape )
73+ self .raw_op = RawPyNFFT3 (self .samples , shape )
0 commit comments