Skip to content

Commit f6199eb

Browse files
authored
Merge pull request #189 from jhlegarreta/sty/add-pet-type-hints
STY: Add type hints to PET-related classes and functions
2 parents 0daaaeb + ff4a9fa commit f6199eb

File tree

3 files changed

+29
-23
lines changed

3 files changed

+29
-23
lines changed

src/nifreeze/data/pet.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,15 @@
2727
import json
2828
from collections import namedtuple
2929
from pathlib import Path
30-
from typing import Callable
30+
from typing import Any, Callable
3131

3232
import attrs
3333
import h5py
3434
import nibabel as nb
3535
import numpy as np
3636
from nibabel.spatialimages import SpatialImage
3737
from nitransforms.linear import Affine
38+
from typing_extensions import Self
3839

3940
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
4041
from nifreeze.utils.ndimage import load_api
@@ -123,11 +124,10 @@ def lofo_split(self, index):
123124

124125
return (train_data, train_timings), (test_data, test_timing)
125126

126-
def set_transform(self, index, affine, order=3):
127+
def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
127128
"""Set an affine, and update data object and gradients."""
128-
reference = namedtuple("ImageGrid", ("shape", "affine"))(
129-
shape=self.dataobj.shape[:3], affine=self.affine
130-
)
129+
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
130+
reference = ImageGrid(self.dataobj.shape[:3], self.affine)
131131
xform = Affine(matrix=affine, reference=reference)
132132

133133
if not Path(self._filepath).exists():
@@ -152,7 +152,9 @@ def set_transform(self, index, affine, order=3):
152152

153153
self.motion_affines[index] = xform
154154

155-
def to_filename(self, filename, compression=None, compression_opts=None):
155+
def to_filename(
156+
self, filename: Path | str, compression: str | None = None, compression_opts: Any = None
157+
) -> None:
156158
"""Write an HDF5 file to disk."""
157159
filename = Path(filename)
158160
if not filename.name.endswith(".h5"):
@@ -183,21 +185,23 @@ def to_nifti(self, filename, *_):
183185
nii.to_filename(filename)
184186

185187
@classmethod
186-
def from_filename(cls, filename):
188+
def from_filename(cls, filename: Path | str) -> Self:
187189
"""Read an HDF5 file from disk."""
188190
with h5py.File(filename, "r") as in_file:
189191
root = in_file["/0"]
190192
data = {k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")}
191193
return cls(**data)
192194

193195
@classmethod
194-
def load(cls, filename, json_file, brainmask_file=None):
196+
def load(
197+
cls, filename: Path | str, json_file: Path | str, brainmask_file: Path | str | None = None
198+
) -> Self:
195199
"""Load PET data."""
196200
filename = Path(filename)
197201
if filename.name.endswith(".h5"):
198202
return cls.from_filename(filename)
199203

200-
img = nb.load(filename)
204+
img = load_api(filename, SpatialImage)
201205
retval = cls(
202206
dataobj=img.get_fdata(dtype="float32"),
203207
affine=img.affine,
@@ -217,7 +221,7 @@ def load(cls, filename, json_file, brainmask_file=None):
217221
assert len(retval.midframe) == retval.dataobj.shape[-1]
218222

219223
if brainmask_file:
220-
mask = nb.load(brainmask_file)
224+
mask = load_api(brainmask_file, SpatialImage)
221225
retval.brainmask = np.asanyarray(mask.dataobj)
222226

223227
return retval

src/nifreeze/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,11 +222,11 @@ def run(self, dataset: DatasetT, **kwargs) -> Self:
222222
class PETMotionEstimator:
223223
"""Estimates motion within PET imaging data aligned with generic Estimator workflow."""
224224

225-
def __init__(self, align_kwargs=None, strategy="lofo"):
225+
def __init__(self, align_kwargs: dict | None = None, strategy: str = "lofo"):
226226
self.align_kwargs = align_kwargs or {}
227227
self.strategy = strategy
228228

229-
def run(self, pet_dataset, omp_nthreads=None):
229+
def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list:
230230
n_frames = len(pet_dataset)
231231
frame_indices = np.arange(n_frames)
232232

src/nifreeze/model/pet.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
"""Models for nuclear imaging."""
2424

2525
from os import cpu_count
26+
from typing import Union
2627

2728
import nibabel as nb
2829
import numpy as np
@@ -31,6 +32,7 @@
3132
from scipy.interpolate import BSpline
3233
from scipy.sparse.linalg import cg
3334

35+
from nifreeze.data.pet import PET
3436
from nifreeze.model.base import BaseModel
3537

3638
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
@@ -54,13 +56,13 @@ class PETModel(BaseModel):
5456

5557
def __init__(
5658
self,
57-
dataset,
58-
timepoints=None,
59-
xlim=None,
60-
n_ctrl=None,
61-
order=3,
62-
smooth_fwhm=10,
63-
thresh_pct=20,
59+
dataset: PET,
60+
timepoints: list | np.ndarray | None = None,
61+
xlim: float | None = None,
62+
n_ctrl: int | None = None,
63+
order: int = 3,
64+
smooth_fwhm: float = 10.0,
65+
thresh_pct: float = 20.0,
6466
**kwargs,
6567
):
6668
"""
@@ -105,7 +107,7 @@ def __init__(
105107
self._mask = None
106108

107109
@property
108-
def is_fitted(self):
110+
def is_fitted(self) -> bool:
109111
return self._locked_fit is not None
110112

111113
def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
@@ -117,7 +119,7 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
117119
if index is not None:
118120
raise NotImplementedError("Fitting with held-out data is not supported")
119121
timepoints = kwargs.get("timepoints", None) or self._x
120-
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
122+
x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)
121123

122124
data = self._dataset.dataobj
123125
brainmask = self._dataset.brainmask
@@ -148,7 +150,7 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
148150

149151
return n_jobs
150152

151-
def fit_predict(self, index: int | None = None, **kwargs):
153+
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
152154
"""Return the corrected volume using B-spline interpolation."""
153155

154156
# Fit the BSpline basis on all data
@@ -159,7 +161,7 @@ def fit_predict(self, index: int | None = None, **kwargs):
159161
return None
160162

161163
# Project sample timing into B-Spline coordinates
162-
x = (index / self._xlim) * self._n_ctrl
164+
x = np.asarray((index / self._xlim) * self._n_ctrl)
163165
A = BSpline.design_matrix(x, self._t, k=self._order)
164166

165167
# A is 1 (num. timepoints) x C (num. coeff)

0 commit comments

Comments
 (0)