Skip to content

Commit 10af03f

Browse files
committed
REF: Refactor PET model
Refactor PET model: use a base class that contains the essential properties for a PET model and create a derived class that implements the B-Spline correction. - Remove the unnecessary explicit `float` casting around the number of control points: the `dtype="float32"` specifier creates a float array. Fixes: ``` Expected type 'str | Buffer | SupportsFloat | SupportsIndex', got 'Literal[0] | None | {__eq__} | int' instead ``` raised by the IDE.
1 parent bb3d129 commit 10af03f

File tree

6 files changed

+157
-72
lines changed

6 files changed

+157
-72
lines changed

docs/notebooks/pet_motion_estimation.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,9 @@
406406
}
407407
],
408408
"source": [
409-
"from nifreeze.model import PETModel\n",
409+
"from nifreeze.model import BSplinePETModel\n",
410410
"\n",
411-
"model = PETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
411+
"model = BSplinePETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
412412
]
413413
},
414414
{
@@ -2258,7 +2258,7 @@
22582258
"from nifreeze.estimator import PETMotionEstimator\n",
22592259
"\n",
22602260
"# Instantiate with a PETModel or appropriate model instance\n",
2261-
"model = PETModel(\n",
2261+
"model = BSplinePETModel(\n",
22622262
" dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=pet_dataset.total_duration\n",
22632263
")\n",
22642264
"estimator = PETMotionEstimator(model=model)\n",

src/nifreeze/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from nifreeze.data.base import BaseDataset
4141
from nifreeze.data.pet import PET
4242
from nifreeze.model.base import BaseModel, ModelFactory
43-
from nifreeze.model.pet import PETModel
43+
from nifreeze.model.pet import BSplinePETModel
4444
from nifreeze.registration.ants import (
4545
Registration,
4646
_prepare_registration_data,
@@ -255,8 +255,8 @@ def run(self, pet_dataset, omp_nthreads=None):
255255
total_duration=pet_dataset.total_duration,
256256
)
257257

258-
# Instantiate PETModel explicitly
259-
model = PETModel(
258+
# Instantiate the PET model explicitly
259+
model = BSplinePETModel(
260260
dataset=train_dataset,
261261
timepoints=train_times,
262262
xlim=pet_dataset.total_duration,

src/nifreeze/model/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
DTIModel,
3434
GPModel,
3535
)
36-
from nifreeze.model.pet import PETModel
36+
from nifreeze.model.pet import BSplinePETModel
3737

3838
__all__ = (
3939
"ModelFactory",
@@ -43,5 +43,5 @@
4343
"DTIModel",
4444
"GPModel",
4545
"TrivialModel",
46-
"PETModel",
46+
"BSplinePETModel",
4747
)

src/nifreeze/model/pet.py

Lines changed: 143 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
#
2323
"""Models for nuclear imaging."""
2424

25+
import abc
26+
from abc import ABC
2527
from os import cpu_count
28+
from typing import Union
2629

2730
import nibabel as nb
2831
import numpy as np
@@ -31,60 +34,88 @@
3134
from scipy.interpolate import BSpline
3235
from scipy.sparse.linalg import cg
3336

37+
from nifreeze.data.pet import PET
3438
from nifreeze.model.base import BaseModel
3539

3640
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3741
"""Time frame tolerance in seconds."""
3842

3943

40-
class PETModel(BaseModel):
41-
"""A PET imaging realignment model based on B-Spline approximation."""
44+
def _exec_fit(model, data, chunk=None, **kwargs):
45+
return model.fit(data, **kwargs), chunk
4246

43-
__slots__ = (
44-
"_t",
45-
"_x",
46-
"_xlim",
47-
"_order",
48-
"_n_ctrl",
49-
"_datashape",
50-
"_mask",
51-
"_smooth_fwhm",
52-
"_thresh_pct",
53-
)
47+
48+
def _exec_predict(model, chunk=None, **kwargs):
49+
"""Propagate model parameters and call predict."""
50+
return np.squeeze(model.predict(**kwargs)), chunk
51+
52+
53+
class BasePETModel(BaseModel, ABC):
54+
"""Interface and default methods for PET models."""
55+
56+
__metaclass__ = abc.ABCMeta
57+
58+
__slots__ = {
59+
"_data_mask": "A mask for the voxels that will be fitted and predicted",
60+
"_x": "",
61+
"_xlim": "",
62+
"_smooth_fwhm": "FWHM in mm over which to smooth",
63+
"_thresh_pct": "Thresholding percentile for the signal",
64+
"_model_class": "Defining a model class",
65+
"_modelargs": "Arguments acceptable by the underlying model",
66+
"_models": "List with one or more (if parallel execution) model instances",
67+
}
5468

5569
def __init__(
5670
self,
57-
dataset,
58-
timepoints=None,
59-
xlim=None,
60-
n_ctrl=None,
61-
order=3,
62-
smooth_fwhm=10,
63-
thresh_pct=20,
71+
dataset: PET,
72+
timepoints: list | np.ndarray = None, ## Is there a way to use array-like
73+
xlim: list | np.ndarray = None,
74+
smooth_fwhm: float = 10.0,
75+
thresh_pct: float = 20.0,
6476
**kwargs,
6577
):
66-
"""
67-
Create the B-Spline interpolating matrix.
78+
"""Initialization.
6879
69-
Parameters:
70-
-----------
71-
timepoints : :obj:`list`
80+
Parameters
81+
----------
82+
timepoints : :obj:`list` or :obj:`~np.ndarray`
7283
The timing (in sec) of each PET volume.
7384
E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
7485
420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
75-
76-
n_ctrl : :obj:`int`
77-
Number of B-Spline control points. If `None`, then one control point every
78-
six timepoints will be used. The less control points, the smoother is the
79-
model.
80-
86+
xlim : .
87+
.
88+
smooth_fwhm : obj:`float`
89+
FWHM in mm over which to smooth the signal.
90+
thresh_pct : obj:`float`
91+
Thresholding percentile for the signal.
8192
"""
93+
8294
super().__init__(dataset, **kwargs)
8395

96+
# Duck typing, instead of explicitly testing for PET type
97+
if not hasattr(dataset, "total_duration"):
98+
raise TypeError("Dataset MUST be a PET object.")
99+
100+
if not hasattr(dataset, "midframe"):
101+
raise ValueError("Dataset MUST have a midframe.")
102+
103+
# ToDO
104+
# Are the timepoints your "gradients" ??? If so, can they be computed
105+
# from frame_time or frame_duration
106+
# Or else frame_time and frame_duration ????
107+
108+
self._data_mask = (
109+
dataset.brainmask
110+
if dataset.brainmask is not None
111+
else np.ones(dataset.dataobj.shape[:3], dtype=bool)
112+
)
113+
114+
# ToDo
115+
# Are timepoints and xlim features that all PET models require ??
84116
if timepoints is None or xlim is None:
85-
raise TypeError("timepoints must be provided in initialization")
117+
raise ValueError("`timepoints` and `xlim` must be specified and have a nonzero value.")
86118

87-
self._order = order
88119
self._x = np.array(timepoints, dtype="float32")
89120
self._xlim = xlim
90121
self._smooth_fwhm = smooth_fwhm
@@ -95,33 +126,15 @@ def __init__(
95126
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
96127
raise ValueError("Last frame midpoint should not be equal or greater than duration")
97128

98-
# Calculate index coordinates in the B-Spline grid
99-
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
100-
101-
# B-Spline knots
102-
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
103-
104-
self._datashape = None
105-
self._mask = None
106-
107-
@property
108-
def is_fitted(self):
109-
return self._locked_fit is not None
110-
111-
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
112-
"""Fit the model."""
113-
114-
if self._locked_fit is not None:
115-
return n_jobs
116-
117-
if index is not None:
118-
raise NotImplementedError("Fitting with held-out data is not supported")
119-
timepoints = kwargs.get("timepoints", None) or self._x
120-
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
129+
super().__init__(dataset, **kwargs)
121130

131+
def _preproces_data(self) -> np.ndarray:
132+
# ToDo
133+
# data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
122134
data = self._dataset.dataobj
123135
brainmask = self._dataset.brainmask
124136

137+
# Preprocess the data
125138
if self._smooth_fwhm > 0:
126139
smoothed_img = smooth_image(
127140
nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
@@ -133,7 +146,73 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
133146
data[data < thresh_val] = 0
134147

135148
# Convert data into V (voxels) x T (timepoints)
136-
data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
149+
return data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
150+
151+
@property
152+
def is_fitted(self) -> bool:
153+
return self._locked_fit is not None
154+
155+
@abc.abstractmethod
156+
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
157+
"""Predict the corrected volume."""
158+
return
159+
160+
161+
class BSplinePETModel(BasePETModel):
162+
"""A PET imaging realignment model based on B-Spline approximation."""
163+
164+
__slots__ = (
165+
"_t",
166+
"_order",
167+
"_n_ctrl",
168+
)
169+
170+
def __init__(
171+
self,
172+
dataset: PET,
173+
n_ctrl: int = None,
174+
order: int = 3,
175+
**kwargs,
176+
):
177+
"""Create the B-Spline interpolating matrix.
178+
179+
Parameters
180+
----------
181+
n_ctrl : :obj:`int`
182+
Number of B-Spline control points. If `None`, then one control point every
183+
six timepoints will be used. The less control points, the smoother is the
184+
model.
185+
order : :obj:`int`
186+
Order of the B-Spline approximation.
187+
"""
188+
189+
super().__init__(dataset, **kwargs)
190+
191+
self._order = order
192+
193+
# Calculate index coordinates in the B-Spline grid
194+
self._n_ctrl = n_ctrl or (len(self._x) // 4) + 1
195+
196+
# B-Spline knots
197+
self._t = np.arange(-3, self._n_ctrl + 4, dtype="float32")
198+
199+
def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> Union[int, None]:
200+
"""Fit the model."""
201+
202+
n_jobs = n_jobs or 1
203+
204+
if self._locked_fit is not None:
205+
return n_jobs
206+
207+
if index is not None:
208+
raise NotImplementedError("Fitting with held-out data is not supported")
209+
210+
data = self._preproces_data()
211+
212+
# ToDo
213+
# Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
214+
timepoints = kwargs.get("timepoints", None) or self._x
215+
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
137216

138217
# A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
139218
A = BSpline.design_matrix(x, self._t, k=self._order)
@@ -146,9 +225,15 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
146225

147226
self._locked_fit = np.array([r[0] for r in results])
148227

149-
def fit_predict(self, index: int | None = None, **kwargs):
228+
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
150229
"""Return the corrected volume using B-spline interpolation."""
151230

231+
# ToDo
232+
# Does the below apply to PET ? Martin has the return None statement
233+
# if index is None:
234+
# raise RuntimeError(
235+
# f"Model {self.__class__.__name__} does not allow locking.")
236+
152237
# Fit the BSpline basis on all data
153238
if self._locked_fit is None:
154239
self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)

test/test_integration_pet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def fit_predict(self, index):
7878
return None
7979
return np.zeros(ds.shape3d, dtype=np.float32)
8080

81-
monkeypatch.setattr("nifreeze.estimator.PETModel", DummyModel)
81+
monkeypatch.setattr("nifreeze.estimator.BSplinePETModel", DummyModel)
8282

8383
class DummyRegistration:
8484
def __init__(self, *args, **kwargs):

test/test_model_pet.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pytest
2626

2727
from nifreeze.data.pet import PET
28-
from nifreeze.model.pet import PETModel
28+
from nifreeze.model.pet import BSplinePETModel
2929

3030

3131
def _create_dataset():
@@ -45,7 +45,7 @@ def _create_dataset():
4545

4646
def test_petmodel_fit_predict():
4747
dataset = _create_dataset()
48-
model = PETModel(
48+
model = BSplinePETModel(
4949
dataset=dataset,
5050
timepoints=dataset.midframe,
5151
xlim=dataset.total_duration,
@@ -65,12 +65,12 @@ def test_petmodel_fit_predict():
6565

6666
def test_petmodel_invalid_init():
6767
dataset = _create_dataset()
68-
with pytest.raises(TypeError):
69-
PETModel(dataset=dataset)
68+
with pytest.raises(ValueError):
69+
BSplinePETModel(dataset=dataset)
7070

7171

7272
def test_petmodel_time_check():
7373
dataset = _create_dataset()
7474
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
7575
with pytest.raises(ValueError):
76-
PETModel(dataset=dataset, timepoints=bad_times, xlim=60.0)
76+
BSplinePETModel(dataset=dataset, timepoints=bad_times, xlim=60.0)

0 commit comments

Comments
 (0)