1010
1111from abc import ABC , abstractmethod
1212from functools import partial
13-
13+ from typing import ClassVar , Callable
1414import numpy as np
15+ from numpy .typing import NDArray
1516
1617from mrinufft ._array_compat import with_numpy , with_numpy_cupy , AUTOGRAD_AVAILABLE
1718from mrinufft ._utils import auto_cast , power_method
1819from mrinufft .density import get_density
1920from mrinufft .extras import get_smaps
2021from mrinufft .operators .interfaces .utils import is_cuda_array , is_host_array
2122
22- if AUTOGRAD_AVAILABLE :
23- from mrinufft .operators .autodiff import MRINufftAutoGrad
24-
25-
2623# Mapping between numpy float and complex types.
2724DTYPE_R2C = {"float32" : "complex64" , "float64" : "complex128" }
2825
@@ -122,6 +119,9 @@ class FourierOperatorBase(ABC):
122119 _grad_wrt_data = False
123120 _grad_wrt_traj = False
124121
122+ backend : ClassVar [str ]
123+ available : ClassVar [bool ]
124+
125125 def __init__ (self ):
126126 if not self .available :
127127 raise RuntimeError (f"'{ self .backend } ' backend is not available." )
@@ -207,21 +207,21 @@ def adj_op(self, coeffs):
207207 """
208208 pass
209209
210- def data_consistency (self , image , obs_data ):
210+ def data_consistency (self , image_data , obs_data ):
211211 """Compute the gradient data consistency.
212212
213213 This is the naive implementation using adj_op(op(x)-y).
214214 Specific backend can (and should!) implement a more efficient version.
215215 """
216- return self .adj_op (self .op (image ) - obs_data )
216+ return self .adj_op (self .op (image_data ) - obs_data )
217217
218218 def with_off_resonance_correction (self , B , C , indices ):
219219 """Return a new operator with Off Resonnance Correction."""
220- from .. off_resonance import MRIFourierCorrected
220+ from .off_resonance import MRIFourierCorrected
221221
222222 return MRIFourierCorrected (self , B , C , indices )
223223
224- def compute_smaps (self , method = None ):
224+ def compute_smaps (self , method : NDArray | Callable | str | dict | None = None ):
225225 """Compute the sensitivity maps and set it.
226226
227227 Parameters
@@ -286,6 +286,8 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
286286 if not self .autograd_available :
287287 raise ValueError ("Backend does not support auto-differentiation." )
288288
289+ from mrinufft .operators .autodiff import MRINufftAutoGrad
290+
289291 return MRINufftAutoGrad (self , wrt_data = wrt_data , wrt_traj = wrt_traj )
290292
291293 def compute_density (self , method = None ):
@@ -401,9 +403,9 @@ def smaps(self):
401403 return self ._smaps
402404
403405 @smaps .setter
404- def smaps (self , smaps ):
405- self ._check_smaps_shape (smaps )
406- self ._smaps = smaps
406+ def smaps (self , new_smaps ):
407+ self ._check_smaps_shape (new_smaps )
408+ self ._smaps = new_smaps
407409
408410 def _check_smaps_shape (self , smaps ):
409411 """Check the shape of the sensitivity maps."""
@@ -421,22 +423,22 @@ def density(self):
421423 return self ._density
422424
423425 @density .setter
424- def density (self , density ):
425- if density is None :
426+ def density (self , new_density ):
427+ if new_density is None :
426428 self ._density = None
427- elif len (density ) != self .n_samples :
429+ elif len (new_density ) != self .n_samples :
428430 raise ValueError ("Density and samples should have the same length" )
429431 else :
430- self ._density = density
432+ self ._density = new_density
431433
432434 @property
433435 def dtype (self ):
434436 """Return floating precision of the operator."""
435437 return self ._dtype
436438
437439 @dtype .setter
438- def dtype (self , dtype ):
439- self ._dtype = np .dtype (dtype )
440+ def dtype (self , new_dtype ):
441+ self ._dtype = np .dtype (new_dtype )
440442
441443 @property
442444 def cpx_dtype (self ):
@@ -449,8 +451,8 @@ def samples(self):
449451 return self ._samples
450452
451453 @samples .setter
452- def samples (self , samples ):
453- self ._samples = samples
454+ def samples (self , new_samples ):
455+ self ._samples = new_samples
454456
455457 @property
456458 def n_samples (self ):
0 commit comments