Skip to content

Commit c56f1ba

Browse files
committed
refactor: consistent name in functions
1 parent 56601e6 commit c56f1ba

File tree

5 files changed

+39
-30
lines changed

5 files changed

+39
-30
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,9 @@ live_mode = false
7676

7777
[tool.mypy]
7878
ignore_missing_imports = true
79+
80+
[tool.pyright]
81+
reportPossiblyUnboundVariable = false
82+
typeCheckingMode = "basic"
83+
reportOptionalSubscript = false
84+
reportOptionalMemberAccess = false

src/mrinufft/operators/base.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,16 @@
1010

1111
from abc import ABC, abstractmethod
1212
from functools import partial
13-
13+
from typing import ClassVar, Callable
1414
import numpy as np
15+
from numpy.typing import NDArray
1516

1617
from mrinufft._array_compat import with_numpy, with_numpy_cupy, AUTOGRAD_AVAILABLE
1718
from mrinufft._utils import auto_cast, power_method
1819
from mrinufft.density import get_density
1920
from mrinufft.extras import get_smaps
2021
from mrinufft.operators.interfaces.utils import is_cuda_array, is_host_array
2122

22-
if AUTOGRAD_AVAILABLE:
23-
from mrinufft.operators.autodiff import MRINufftAutoGrad
24-
25-
2623
# Mapping between numpy float and complex types.
2724
DTYPE_R2C = {"float32": "complex64", "float64": "complex128"}
2825

@@ -122,6 +119,9 @@ class FourierOperatorBase(ABC):
122119
_grad_wrt_data = False
123120
_grad_wrt_traj = False
124121

122+
backend: ClassVar[str]
123+
available: ClassVar[bool]
124+
125125
def __init__(self):
126126
if not self.available:
127127
raise RuntimeError(f"'{self.backend}' backend is not available.")
@@ -207,21 +207,21 @@ def adj_op(self, coeffs):
207207
"""
208208
pass
209209

210-
def data_consistency(self, image, obs_data):
210+
def data_consistency(self, image_data, obs_data):
211211
"""Compute the gradient data consistency.
212212
213213
This is the naive implementation using adj_op(op(x)-y).
214214
Specific backend can (and should!) implement a more efficient version.
215215
"""
216-
return self.adj_op(self.op(image) - obs_data)
216+
return self.adj_op(self.op(image_data) - obs_data)
217217

218218
def with_off_resonance_correction(self, B, C, indices):
219219
"""Return a new operator with Off Resonnance Correction."""
220-
from ..off_resonance import MRIFourierCorrected
220+
from .off_resonance import MRIFourierCorrected
221221

222222
return MRIFourierCorrected(self, B, C, indices)
223223

224-
def compute_smaps(self, method=None):
224+
def compute_smaps(self, method: NDArray | Callable | str | dict | None = None):
225225
"""Compute the sensitivity maps and set it.
226226
227227
Parameters
@@ -286,6 +286,8 @@ def make_autograd(self, wrt_data=True, wrt_traj=False):
286286
if not self.autograd_available:
287287
raise ValueError("Backend does not support auto-differentiation.")
288288

289+
from mrinufft.operators.autodiff import MRINufftAutoGrad
290+
289291
return MRINufftAutoGrad(self, wrt_data=wrt_data, wrt_traj=wrt_traj)
290292

291293
def compute_density(self, method=None):
@@ -401,9 +403,9 @@ def smaps(self):
401403
return self._smaps
402404

403405
@smaps.setter
404-
def smaps(self, smaps):
405-
self._check_smaps_shape(smaps)
406-
self._smaps = smaps
406+
def smaps(self, new_smaps):
407+
self._check_smaps_shape(new_smaps)
408+
self._smaps = new_smaps
407409

408410
def _check_smaps_shape(self, smaps):
409411
"""Check the shape of the sensitivity maps."""
@@ -421,22 +423,22 @@ def density(self):
421423
return self._density
422424

423425
@density.setter
424-
def density(self, density):
425-
if density is None:
426+
def density(self, new_density):
427+
if new_density is None:
426428
self._density = None
427-
elif len(density) != self.n_samples:
429+
elif len(new_density) != self.n_samples:
428430
raise ValueError("Density and samples should have the same length")
429431
else:
430-
self._density = density
432+
self._density = new_density
431433

432434
@property
433435
def dtype(self):
434436
"""Return floating precision of the operator."""
435437
return self._dtype
436438

437439
@dtype.setter
438-
def dtype(self, dtype):
439-
self._dtype = np.dtype(dtype)
440+
def dtype(self, new_dtype):
441+
self._dtype = np.dtype(new_dtype)
440442

441443
@property
442444
def cpx_dtype(self):
@@ -449,8 +451,8 @@ def samples(self):
449451
return self._samples
450452

451453
@samples.setter
452-
def samples(self, samples):
453-
self._samples = samples
454+
def samples(self, new_samples):
455+
self._samples = new_samples
454456

455457
@property
456458
def n_samples(self):

src/mrinufft/operators/interfaces/cufinufft.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
except ImportError:
2727
CUFINUFFT_AVAILABLE = False
2828

29-
3029
OPTS_FIELD_DECODE = {
3130
"gpu_method": {1: "nonuniform pts driven", 2: "shared memory"},
3231
"gpu_sort": {0: "no sort (GM)", 1: "sort (GM-sort)"},
@@ -269,10 +268,12 @@ def smaps(self, new_smaps):
269268
self._smaps = new_smaps
270269

271270
@FourierOperatorBase.samples.setter
272-
def samples(self, samples):
271+
def samples(self, new_samples):
273272
"""Update the plans when changing the samples."""
274273
self._samples = np.asfortranarray(
275-
proper_trajectory(samples, normalize="pi").astype(np.float32, copy=False)
274+
proper_trajectory(new_samples, normalize="pi").astype(
275+
np.float32, copy=False
276+
)
276277
)
277278
for typ in [1, 2, "grad"]:
278279
if typ == "grad" and not self._grad_wrt_traj:

src/mrinufft/operators/interfaces/gpunufft.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ def smaps(self, new_smaps):
532532
self.raw_op.set_smaps(smaps=new_smaps)
533533

534534
@FourierOperatorBase.samples.setter
535-
def samples(self, samples):
535+
def samples(self, new_samples):
536536
"""Set the samples for the Fourier Operator.
537537
538538
Parameters
@@ -541,7 +541,7 @@ def samples(self, samples):
541541
The samples for the Fourier Operator.
542542
"""
543543
self._samples = proper_trajectory(
544-
samples.astype(np.float32, copy=False), normalize="unit"
544+
new_samples.astype(np.float32, copy=False), normalize="unit"
545545
)
546546
# TODO: gpuNUFFT needs to sort the points twice in this case.
547547
# It could help to have access to directly dorted arrays from gpuNUFFT.
@@ -552,19 +552,19 @@ def samples(self, samples):
552552
)
553553

554554
@FourierOperatorBase.density.setter
555-
def density(self, density):
555+
def density(self, new_density):
556556
"""Set the density for the Fourier Operator.
557557
558558
Parameters
559559
----------
560560
density: np.ndarray
561561
The density for the Fourier Operator.
562562
"""
563-
self._density = density
563+
self._density = new_density
564564
if hasattr(self, "raw_op"): # edge case for init
565565
self.raw_op.set_pts(
566566
self._samples,
567-
density=density,
567+
density=new_density,
568568
)
569569

570570
@classmethod

src/mrinufft/operators/interfaces/tfnufft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def norm_factor(self):
134134
return np.sqrt(np.prod(self.shape) * 2 ** len(self.shape))
135135

136136
@with_tensorflow
137-
def data_consistency(self, data, obs_data):
137+
def data_consistency(self, image_data, obs_data):
138138
"""Compute the data consistency.
139139
140140
Parameters
@@ -149,7 +149,7 @@ def data_consistency(self, data, obs_data):
149149
Tensor
150150
The data consistency error in image space.
151151
"""
152-
return self.adj_op(self.op(data) - obs_data)
152+
return self.adj_op(self.op(image_data) - obs_data)
153153

154154
@classmethod
155155
def pipe(

0 commit comments

Comments
 (0)