Skip to content

Commit 4751545

Browse files
committed
REF: Refactor PET model
Refactor PET model: use a base class that contains the essential properties for a PET model and create a derived class that implements the B-Spline correction. - Make the `x_lim` private attribute be an array when assigning the corresponding value at initialization to enable NumPy broadcasting. - Remove the unnecessary explicit `float` casting around the number of control points: the `dtype="float32"` specifier creates a float array. Fixes: ``` Expected type 'str | Buffer | SupportsFloat | SupportsIndex', got 'Literal[0] | None | {__eq__} | int' instead ``` raised by the IDE. - Use `np.asarray` to avoid extra copies when not necessary.
1 parent 19af931 commit 4751545

File tree

6 files changed

+166
-74
lines changed

6 files changed

+166
-74
lines changed

docs/notebooks/pet_motion_estimation.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,9 @@
406406
}
407407
],
408408
"source": [
409-
"from nifreeze.model import PETModel\n",
409+
"from nifreeze.model import BSplinePETModel\n",
410410
"\n",
411-
"model = PETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
411+
"model = BSplinePETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
412412
]
413413
},
414414
{

src/nifreeze/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from nifreeze.data.base import BaseDataset
4141
from nifreeze.data.pet import PET
4242
from nifreeze.model.base import BaseModel, ModelFactory
43-
from nifreeze.model.pet import PETModel
43+
from nifreeze.model.pet import BSplinePETModel
4444
from nifreeze.registration.ants import (
4545
Registration,
4646
_prepare_registration_data,
@@ -255,8 +255,8 @@ def run(self, pet_dataset, omp_nthreads=None):
255255
total_duration=pet_dataset.total_duration,
256256
)
257257

258-
# Instantiate PETModel explicitly
259-
model = PETModel(
258+
# Instantiate the PET model explicitly
259+
model = BSplinePETModel(
260260
dataset=train_dataset,
261261
timepoints=train_times,
262262
xlim=pet_dataset.total_duration,

src/nifreeze/model/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
DTIModel,
3434
GPModel,
3535
)
36-
from nifreeze.model.pet import PETModel
36+
from nifreeze.model.pet import BSplinePETModel
3737

3838
__all__ = (
3939
"ModelFactory",
@@ -43,5 +43,5 @@
4343
"DTIModel",
4444
"GPModel",
4545
"TrivialModel",
46-
"PETModel",
46+
"BSplinePETModel",
4747
)

src/nifreeze/model/pet.py

Lines changed: 143 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
#
2323
"""Models for nuclear imaging."""
2424

25+
from abc import ABC, ABCMeta, abstractmethod
2526
from os import cpu_count
27+
from typing import Union
2628

2729
import nibabel as nb
2830
import numpy as np
@@ -31,62 +33,90 @@
3133
from scipy.interpolate import BSpline
3234
from scipy.sparse.linalg import cg
3335

36+
from nifreeze.data.pet import PET
3437
from nifreeze.model.base import BaseModel
3538

3639
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3740
"""Time frame tolerance in seconds."""
3841

3942

40-
class PETModel(BaseModel):
41-
"""A PET imaging realignment model based on B-Spline approximation."""
43+
def _exec_fit(model, data, chunk=None, **kwargs):
44+
return model.fit(data, **kwargs), chunk
4245

43-
__slots__ = (
44-
"_t",
45-
"_x",
46-
"_xlim",
47-
"_order",
48-
"_n_ctrl",
49-
"_datashape",
50-
"_mask",
51-
"_smooth_fwhm",
52-
"_thresh_pct",
53-
)
46+
47+
def _exec_predict(model, chunk=None, **kwargs):
48+
"""Propagate model parameters and call predict."""
49+
return np.squeeze(model.predict(**kwargs)), chunk
50+
51+
52+
class BasePETModel(BaseModel, ABC):
53+
"""Interface and default methods for PET models."""
54+
55+
__metaclass__ = ABCMeta
56+
57+
__slots__ = {
58+
"_data_mask": "A mask for the voxels that will be fitted and predicted",
59+
"_x": "",
60+
"_xlim": "",
61+
"_smooth_fwhm": "FWHM in mm over which to smooth",
62+
"_thresh_pct": "Thresholding percentile for the signal",
63+
"_model_class": "Defining a model class",
64+
"_modelargs": "Arguments acceptable by the underlying model",
65+
"_models": "List with one or more (if parallel execution) model instances",
66+
}
5467

5568
def __init__(
5669
self,
57-
dataset,
58-
timepoints=None,
59-
xlim=None,
60-
n_ctrl=None,
61-
order=3,
62-
smooth_fwhm=10,
63-
thresh_pct=20,
70+
dataset: PET,
71+
timepoints: list | np.ndarray | None = None, ## Is there a way to use array-like
72+
xlim: list | np.ndarray | None = None,
73+
smooth_fwhm: float = 10.0,
74+
thresh_pct: float = 20.0,
6475
**kwargs,
6576
):
66-
"""
67-
Create the B-Spline interpolating matrix.
77+
"""Initialization.
6878
69-
Parameters:
70-
-----------
71-
timepoints : :obj:`list`
79+
Parameters
80+
----------
81+
timepoints : :obj:`list` or :obj:`~np.ndarray`
7282
The timing (in sec) of each PET volume.
7383
E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
7484
420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
75-
76-
n_ctrl : :obj:`int`
77-
Number of B-Spline control points. If `None`, then one control point every
78-
six timepoints will be used. The less control points, the smoother is the
79-
model.
80-
85+
xlim : .
86+
.
87+
smooth_fwhm : obj:`float`
88+
FWHM in mm over which to smooth the signal.
89+
thresh_pct : obj:`float`
90+
Thresholding percentile for the signal.
8191
"""
92+
8293
super().__init__(dataset, **kwargs)
8394

95+
# Duck typing, instead of explicitly testing for PET type
96+
if not hasattr(dataset, "total_duration"):
97+
raise TypeError("Dataset MUST be a PET object.")
98+
99+
if not hasattr(dataset, "midframe"):
100+
raise ValueError("Dataset MUST have a midframe.")
101+
102+
# ToDO
103+
# Are the timepoints your "gradients" ??? If so, can they be computed
104+
# from frame_time or frame_duration
105+
# Or else frame_time and frame_duration ????
106+
107+
self._data_mask = (
108+
dataset.brainmask
109+
if dataset.brainmask is not None
110+
else np.ones(dataset.dataobj.shape[:3], dtype=bool)
111+
)
112+
113+
# ToDo
114+
# Are timepoints and xlim features that all PET models require ??
84115
if timepoints is None or xlim is None:
85-
raise TypeError("timepoints must be provided in initialization")
116+
raise ValueError("`timepoints` and `xlim` must be specified and have a nonzero value.")
86117

87-
self._order = order
88-
self._x = np.array(timepoints, dtype="float32")
89-
self._xlim = xlim
118+
self._x = np.asarray(timepoints, dtype="float32")
119+
self._xlim = np.asarray(xlim)
90120
self._smooth_fwhm = smooth_fwhm
91121
self._thresh_pct = thresh_pct
92122

@@ -95,62 +125,114 @@ def __init__(
95125
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
96126
raise ValueError("Last frame midpoint should not be equal or greater than duration")
97127

98-
# Calculate index coordinates in the B-Spline grid
99-
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
128+
def _preproces_data(self) -> np.ndarray:
129+
# ToDo
130+
# data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
131+
data = self._dataset.dataobj
132+
brainmask = self._dataset.brainmask
100133

101-
# B-Spline knots
102-
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
134+
# Preprocess the data
135+
if self._smooth_fwhm > 0:
136+
smoothed_img = smooth_image(
137+
nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
138+
)
139+
data = smoothed_img.get_fdata()
140+
141+
if self._thresh_pct > 0:
142+
thresh_val = np.percentile(data, self._thresh_pct)
143+
data[data < thresh_val] = 0
103144

104-
self._datashape = None
105-
self._mask = None
145+
# Convert data into V (voxels) x T (timepoints)
146+
return data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
106147

107148
@property
108-
def is_fitted(self):
149+
def is_fitted(self) -> bool:
109150
return self._locked_fit is not None
110151

152+
@abstractmethod
153+
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
154+
"""Predict the corrected volume."""
155+
return None
156+
157+
158+
class BSplinePETModel(BasePETModel):
159+
"""A PET imaging realignment model based on B-Spline approximation."""
160+
161+
__slots__ = (
162+
"_t",
163+
"_order",
164+
"_n_ctrl",
165+
)
166+
167+
def __init__(
168+
self,
169+
dataset: PET,
170+
n_ctrl: int | None = None,
171+
order: int = 3,
172+
**kwargs,
173+
):
174+
"""Create the B-Spline interpolating matrix.
175+
176+
Parameters
177+
----------
178+
n_ctrl : :obj:`int`
179+
Number of B-Spline control points. If `None`, then one control point every
180+
six timepoints will be used. The less control points, the smoother is the
181+
model.
182+
order : :obj:`int`
183+
Order of the B-Spline approximation.
184+
"""
185+
186+
super().__init__(dataset, **kwargs)
187+
188+
self._order = order
189+
190+
# Calculate index coordinates in the B-Spline grid
191+
self._n_ctrl = n_ctrl or (len(self._x) // 4) + 1
192+
193+
# B-Spline knots
194+
self._t = np.arange(-3, self._n_ctrl + 4, dtype="float32")
195+
111196
def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
112197
"""Fit the model."""
113198

199+
n_jobs = n_jobs or min(cpu_count() or 1, 8)
200+
114201
if self._locked_fit is not None:
115202
return n_jobs
116203

117204
if index is not None:
118205
raise NotImplementedError("Fitting with held-out data is not supported")
119-
timepoints = kwargs.get("timepoints", None) or self._x
120-
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
121206

122-
data = self._dataset.dataobj
123-
brainmask = self._dataset.brainmask
207+
data = self._preproces_data()
124208

125-
if self._smooth_fwhm > 0:
126-
smoothed_img = smooth_image(
127-
nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
128-
)
129-
data = smoothed_img.get_fdata()
130-
131-
if self._thresh_pct > 0:
132-
thresh_val = np.percentile(data, self._thresh_pct)
133-
data[data < thresh_val] = 0
134-
135-
# Convert data into V (voxels) x T (timepoints)
136-
data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
209+
# ToDo
210+
# Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
211+
timepoints = kwargs.get("timepoints", None) or self._x
212+
x = np.asarray(timepoints, dtype="float32") / self._xlim * self._n_ctrl
137213

138214
# A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
139215
A = BSpline.design_matrix(x, self._t, k=self._order)
140216
AT = A.T
141217
ATdotA = AT @ A
142218

143219
# Parallelize process with joblib
144-
with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor:
220+
with Parallel(n_jobs=n_jobs) as executor:
145221
results = executor(delayed(cg)(ATdotA, AT @ v) for v in data)
146222

147-
self._locked_fit = np.array([r[0] for r in results])
223+
self._locked_fit = np.asarray([r[0] for r in results])
148224

149225
return n_jobs
150226

151-
def fit_predict(self, index: int | None = None, **kwargs):
227+
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
152228
"""Return the corrected volume using B-spline interpolation."""
153229

230+
# ToDo
231+
# Does the below apply to PET ? Martin has the return None statement
232+
# if index is None:
233+
# raise RuntimeError(
234+
# f"Model {self.__class__.__name__} does not allow locking.")
235+
154236
# Fit the BSpline basis on all data
155237
if self._locked_fit is None:
156238
self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)

test/test_integration_pet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def fit_predict(self, index):
7878
return None
7979
return np.zeros(ds.shape3d, dtype=np.float32)
8080

81-
monkeypatch.setattr("nifreeze.estimator.PETModel", DummyModel)
81+
monkeypatch.setattr("nifreeze.estimator.BSplinePETModel", DummyModel)
8282

8383
class DummyRegistration:
8484
def __init__(self, *args, **kwargs):

test/test_model_pet.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import pytest
2626

2727
from nifreeze.data.pet import PET
28-
from nifreeze.model.pet import PETModel
28+
from nifreeze.model.pet import BSplinePETModel
2929

3030

3131
def _create_dataset():
@@ -43,9 +43,19 @@ def _create_dataset():
4343
)
4444

4545

46+
def test_pet_base_model():
47+
from nifreeze.model.pet import BasePETModel
48+
49+
with pytest.raises(
50+
TypeError,
51+
match="Can't instantiate abstract class BasePETModel with abstract method fit_predict",
52+
):
53+
BasePETModel(None)
54+
55+
4656
def test_petmodel_fit_predict():
4757
dataset = _create_dataset()
48-
model = PETModel(
58+
model = BSplinePETModel(
4959
dataset=dataset,
5060
timepoints=dataset.midframe,
5161
xlim=dataset.total_duration,
@@ -65,12 +75,12 @@ def test_petmodel_fit_predict():
6575

6676
def test_petmodel_invalid_init():
6777
dataset = _create_dataset()
68-
with pytest.raises(TypeError):
69-
PETModel(dataset=dataset)
78+
with pytest.raises(ValueError):
79+
BSplinePETModel(dataset=dataset)
7080

7181

7282
def test_petmodel_time_check():
7383
dataset = _create_dataset()
7484
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
7585
with pytest.raises(ValueError):
76-
PETModel(dataset=dataset, timepoints=bad_times, xlim=60.0)
86+
BSplinePETModel(dataset=dataset, timepoints=bad_times, xlim=60.0)

0 commit comments

Comments
 (0)