Skip to content

Commit 119ec4e

Browse files
committed
STY: Add type hints to PET-related classes and functions
Add type hints to PET-related classes and functions. Notably, use positional arguments instead of keyword arguments when instantiating a `namedtuple`.
1 parent bb3d129 commit 119ec4e

File tree

3 files changed

+29
-21
lines changed

3 files changed

+29
-21
lines changed

src/nifreeze/data/pet.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,15 @@
2727
import json
2828
from collections import namedtuple
2929
from pathlib import Path
30+
from typing import Any, Tuple
3031

3132
import attrs
3233
import h5py
3334
import nibabel as nb
3435
import numpy as np
3536
from nibabel.spatialimages import SpatialImage
3637
from nitransforms.linear import Affine
38+
from typing_extensions import Self
3739

3840
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
3941
from nifreeze.utils.ndimage import load_api
@@ -79,7 +81,9 @@ def __getitem__(
7981
"""
8082
return super().__getitem__(idx)
8183

82-
def lofo_split(self, index):
84+
def lofo_split(
85+
self, index: int
86+
) -> Tuple[Tuple[np.ndarray, np.ndarray | None], Tuple[np.ndarray, np.ndarray | None]]:
8387
"""
8488
Leave-one-frame-out (LOFO) for PET data.
8589
@@ -118,11 +122,10 @@ def lofo_split(self, index):
118122

119123
return (train_data, train_timings), (test_data, test_timing)
120124

121-
def set_transform(self, index, affine, order=3):
125+
def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
122126
"""Set an affine, and update data object and gradients."""
123-
reference = namedtuple("ImageGrid", ("shape", "affine"))(
124-
shape=self.dataobj.shape[:3], affine=self.affine
125-
)
127+
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
128+
reference = ImageGrid(self.dataobj.shape[:3], self.affine)
126129
xform = Affine(matrix=affine, reference=reference)
127130

128131
if not Path(self._filepath).exists():
@@ -147,7 +150,9 @@ def set_transform(self, index, affine, order=3):
147150

148151
self.motion_affines[index] = xform
149152

150-
def to_filename(self, filename, compression=None, compression_opts=None):
153+
def to_filename(
154+
self, filename: Path | str, compression: str | None = None, compression_opts: Any = None
155+
) -> None:
151156
"""Write an HDF5 file to disk."""
152157
filename = Path(filename)
153158
if not filename.name.endswith(".h5"):
@@ -178,21 +183,23 @@ def to_nifti(self, filename, *_):
178183
nii.to_filename(filename)
179184

180185
@classmethod
181-
def from_filename(cls, filename):
186+
def from_filename(cls, filename: Path | str) -> Self:
182187
"""Read an HDF5 file from disk."""
183188
with h5py.File(filename, "r") as in_file:
184189
root = in_file["/0"]
185190
data = {k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")}
186191
return cls(**data)
187192

188193
@classmethod
189-
def load(cls, filename, json_file, brainmask_file=None):
194+
def load(
195+
cls, filename: Path | str, json_file: Path | str, brainmask_file: Path | str | None = None
196+
) -> Self:
190197
"""Load PET data."""
191198
filename = Path(filename)
192199
if filename.name.endswith(".h5"):
193200
return cls.from_filename(filename)
194201

195-
img = nb.load(filename)
202+
img = load_api(filename, SpatialImage)
196203
retval = cls(
197204
dataobj=img.get_fdata(dtype="float32"),
198205
affine=img.affine,
@@ -212,7 +219,7 @@ def load(cls, filename, json_file, brainmask_file=None):
212219
assert len(retval.midframe) == retval.dataobj.shape[-1]
213220

214221
if brainmask_file:
215-
mask = nb.load(brainmask_file)
222+
mask = load_api(brainmask_file, SpatialImage)
216223
retval.brainmask = np.asanyarray(mask.dataobj)
217224

218225
return retval

src/nifreeze/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,11 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
222222
class PETMotionEstimator:
223223
"""Estimates motion within PET imaging data aligned with generic Estimator workflow."""
224224

225-
def __init__(self, align_kwargs=None, strategy="lofo"):
225+
def __init__(self, align_kwargs: dict | None = None, strategy: str = "lofo"):
226226
self.align_kwargs = align_kwargs or {}
227227
self.strategy = strategy
228228

229-
def run(self, pet_dataset, omp_nthreads=None):
229+
def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list:
230230
n_frames = len(pet_dataset)
231231
frame_indices = np.arange(n_frames)
232232

src/nifreeze/model/pet.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from scipy.interpolate import BSpline
3232
from scipy.sparse.linalg import cg
3333

34+
from nifreeze.data.pet import PET
3435
from nifreeze.model.base import BaseModel
3536

3637
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
@@ -54,13 +55,13 @@ class PETModel(BaseModel):
5455

5556
def __init__(
5657
self,
57-
dataset,
58-
timepoints=None,
59-
xlim=None,
60-
n_ctrl=None,
61-
order=3,
62-
smooth_fwhm=10,
63-
thresh_pct=20,
58+
dataset: PET,
59+
timepoints: list | np.ndarray | None = None,
60+
xlim: float | None = None,
61+
n_ctrl: int | None = None,
62+
order: int = 3,
63+
smooth_fwhm: float = 10.0,
64+
thresh_pct: float = 20.0,
6465
**kwargs,
6566
):
6667
"""
@@ -105,7 +106,7 @@ def __init__(
105106
self._mask = None
106107

107108
@property
108-
def is_fitted(self):
109+
def is_fitted(self) -> bool:
109110
return self._locked_fit is not None
110111

111112
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
@@ -117,7 +118,7 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
117118
if index is not None:
118119
raise NotImplementedError("Fitting with held-out data is not supported")
119120
timepoints = kwargs.get("timepoints", None) or self._x
120-
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
121+
x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)
121122

122123
data = self._dataset.dataobj
123124
brainmask = self._dataset.brainmask

0 commit comments

Comments
 (0)