Skip to content

Commit 635fe57

Browse files
committed
STY: Add type hints to PET-related classes and functions
Add type hints to PET-related classes and functions. Notably, use positional arguments instead of keyword arguments when instantiating a `namedtuple`.
1 parent cb13535 commit 635fe57

File tree

3 files changed

+31
-22
lines changed

3 files changed

+31
-22
lines changed

src/nifreeze/data/pet.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
import json
2828
from collections import namedtuple
2929
from pathlib import Path
30+
from typing import Any, Tuple
3031

3132
import attrs
3233
import h5py
3334
import nibabel as nb
3435
import numpy as np
3536
from nibabel.spatialimages import SpatialImage
3637
from nitransforms.linear import Affine
38+
from typing_extensions import Self
3739

3840
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
3941
from nifreeze.utils.ndimage import load_api
@@ -79,7 +81,9 @@ def __getitem__(
7981
"""
8082
return super().__getitem__(idx)
8183

82-
def lofo_split(self, index):
84+
def lofo_split(
85+
self, index: int
86+
) -> Tuple[Tuple[np.ndarray, np.ndarray | None], Tuple[np.ndarray, np.ndarray | None]]:
8387
"""
8488
Leave-one-frame-out (LOFO) for PET data.
8589
@@ -118,11 +122,10 @@ def lofo_split(self, index):
118122

119123
return (train_data, train_timings), (test_data, test_timing)
120124

121-
def set_transform(self, index, affine, order=3):
125+
def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
122126
"""Set an affine, and update data object and gradients."""
123-
reference = namedtuple("ImageGrid", ("shape", "affine"))(
124-
shape=self.dataobj.shape[:3], affine=self.affine
125-
)
127+
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
128+
reference = ImageGrid(self.dataobj.shape[:3], self.affine)
126129
xform = Affine(matrix=affine, reference=reference)
127130

128131
if not Path(self._filepath).exists():
@@ -147,7 +150,9 @@ def set_transform(self, index, affine, order=3):
147150

148151
self.motion_affines[index] = xform
149152

150-
def to_filename(self, filename, compression=None, compression_opts=None):
153+
def to_filename(
154+
self, filename: Path | str, compression: str | None = None, compression_opts: Any = None
155+
) -> None:
151156
"""Write an HDF5 file to disk."""
152157
filename = Path(filename)
153158
if not filename.name.endswith(".h5"):
@@ -178,21 +183,23 @@ def to_nifti(self, filename, *_):
178183
nii.to_filename(filename)
179184

180185
@classmethod
181-
def from_filename(cls, filename):
186+
def from_filename(cls, filename: Path | str) -> Self:
182187
"""Read an HDF5 file from disk."""
183188
with h5py.File(filename, "r") as in_file:
184189
root = in_file["/0"]
185190
data = {k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")}
186191
return cls(**data)
187192

188193
@classmethod
189-
def load(cls, filename, json_file, brainmask_file=None):
194+
def load(
195+
cls, filename: Path | str, json_file: Path | str, brainmask_file: Path | str | None = None
196+
) -> Self:
190197
"""Load PET data."""
191198
filename = Path(filename)
192199
if filename.name.endswith(".h5"):
193200
return cls.from_filename(filename)
194201

195-
img = nb.load(filename)
202+
img = load_api(filename, SpatialImage)
196203
retval = cls(
197204
dataobj=img.get_fdata(dtype="float32"),
198205
affine=img.affine,
@@ -212,7 +219,7 @@ def load(cls, filename, json_file, brainmask_file=None):
212219
assert len(retval.midframe) == retval.dataobj.shape[-1]
213220

214221
if brainmask_file:
215-
mask = nb.load(brainmask_file)
222+
mask = load_api(brainmask_file, SpatialImage)
216223
retval.brainmask = np.asanyarray(mask.dataobj)
217224

218225
return retval

src/nifreeze/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,11 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
222222
class PETMotionEstimator:
223223
"""Estimates motion within PET imaging data aligned with generic Estimator workflow."""
224224

225-
def __init__(self, align_kwargs=None, strategy="lofo"):
225+
def __init__(self, align_kwargs: dict | None = None, strategy: str = "lofo"):
226226
self.align_kwargs = align_kwargs or {}
227227
self.strategy = strategy
228228

229-
def run(self, pet_dataset, omp_nthreads=None):
229+
def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list:
230230
n_frames = len(pet_dataset)
231231
frame_indices = np.arange(n_frames)
232232

src/nifreeze/model/pet.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""Models for nuclear imaging."""
2424

2525
from os import cpu_count
26+
from typing import Union
2627

2728
import nibabel as nb
2829
import numpy as np
@@ -31,6 +32,7 @@
3132
from scipy.interpolate import BSpline
3233
from scipy.sparse.linalg import cg
3334

35+
from nifreeze.data.pet import PET
3436
from nifreeze.model.base import BaseModel
3537

3638
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
@@ -54,13 +56,13 @@ class PETModel(BaseModel):
5456

5557
def __init__(
5658
self,
57-
dataset,
58-
timepoints=None,
59-
xlim=None,
60-
n_ctrl=None,
61-
order=3,
62-
smooth_fwhm=10,
63-
thresh_pct=20,
59+
dataset: PET,
60+
timepoints: list | np.ndarray | None = None,
61+
xlim: float | None = None,
62+
n_ctrl: int | None = None,
63+
order: int = 3,
64+
smooth_fwhm: float = 10.0,
65+
thresh_pct: float = 20.0,
6466
**kwargs,
6567
):
6668
"""
@@ -105,7 +107,7 @@ def __init__(
105107
self._mask = None
106108

107109
@property
108-
def is_fitted(self):
110+
def is_fitted(self) -> bool:
109111
return self._locked_fit is not None
110112

111113
def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
@@ -117,7 +119,7 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
117119
if index is not None:
118120
raise NotImplementedError("Fitting with held-out data is not supported")
119121
timepoints = kwargs.get("timepoints", None) or self._x
120-
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
122+
x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)
121123

122124
data = self._dataset.dataobj
123125
brainmask = self._dataset.brainmask
@@ -148,7 +150,7 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
148150

149151
return n_jobs
150152

151-
def fit_predict(self, index: int | None = None, **kwargs):
153+
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
152154
"""Return the corrected volume using B-spline interpolation."""
153155

154156
# Fit the BSpline basis on all data

0 commit comments

Comments
 (0)