Skip to content

Commit a2682fc

Browse files
authored
feat: move to pynfft3. (#274)
1 parent 782224b commit a2682fc

File tree

2 files changed

+13
-14
lines changed

2 files changed

+13
-14
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ gpunufft = ["gpuNUFFT>=0.9.0", "cupy-cuda12x"]
1515
torchkbnufft = ["torchkbnufft", "cupy-cuda12x"]
1616
cufinufft = ["cufinufft<2.3", "cupy-cuda12x"]
1717
finufft = ["finufft"]
18-
pynfft = ["pynfft2>=1.4.3; python_version < '3.12'", "numpy>=2.0.0; python_version < '3.12'"]
18+
pynfft = ["pynfft3"]
1919
pynufft = ["pynufft"]
2020
extra = ["pymapvbvd", "scikit-image", "scikit-learn", "pywavelets"]
2121
autodiff = ["torch"]

src/mrinufft/operators/interfaces/nfft.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
PYNFFT_AVAILABLE = True
88
try:
9-
import pynfft
9+
import pyNFFT3 as pynfft3
1010
except ImportError:
11-
PYNUFFT_AVAILABLE = False
11+
PYNFFT_AVAILABLE = False
1212

1313

1414
def get_fourier_matrix(ktraj, shape, ndim, do_ifft=False):
@@ -24,28 +24,27 @@ def get_fourier_matrix(ktraj, shape, ndim, do_ifft=False):
2424
return matrix / np.sqrt(n)
2525

2626

27-
class RawPyNFFT:
28-
"""Implementation of the NUDFT using numpy."""
27+
class RawPyNFFT3:
28+
"""Binding for the pyNFFT3 package."""
2929

3030
def __init__(self, samples, shape):
3131
self.samples = samples
3232
self.shape = shape
33-
self.ndim = len(shape)
34-
self.plan = pynfft.NFFT(N=shape, M=len(samples))
33+
self.plan = pynfft3.NFFT(N=np.array(shape, dtype="int32"), M=len(samples))
3534
self.plan.x = self.samples
36-
self.plan.precompute()
37-
self.shape = shape
3835

3936
def op(self, coeffs, image):
4037
"""Compute the forward NUDFT."""
41-
self.plan.f_hat = image
42-
np.copyto(coeffs, self.plan.trafo())
38+
self.plan.fhat = image.ravel()
39+
self.plan.trafo()
40+
np.copyto(coeffs, self.plan.f.reshape(-1))
4341
return coeffs
4442

4543
def adj_op(self, coeffs, image):
4644
"""Compute the adjoint NUDFT."""
47-
self.plan.f = coeffs
48-
np.copyto(image, self.plan.adjoint())
45+
self.plan.f = coeffs.ravel()
46+
self.plan.adjoint()
47+
np.copyto(image, self.plan.fhat.reshape(self.shape))
4948
return image
5049

5150

@@ -71,4 +70,4 @@ def __init__(
7170
density=density,
7271
raw_op=None, # is set later, after normalizing samples.
7372
)
74-
self.raw_op = RawPyNFFT(self.samples, shape)
73+
self.raw_op = RawPyNFFT3(self.samples, shape)

0 commit comments

Comments
 (0)