Skip to content

Commit 90eb886

Browse files
committed
REF: Refactor PET model
Refactor PET model: use a base class that contains the essential properties for a PET model and create a derived class that implements the B-Spline correction. - Remove the unnecessary explicit `float` casting around the number of control points: the `dtype="float32"` specifier creates a float array. Fixes: ``` Expected type 'str | Buffer | SupportsFloat | SupportsIndex', got 'Literal[0] | None | {__eq__} | int' instead ``` raised by the IDE.
1 parent 676cd91 commit 90eb886

File tree

5 files changed

+218
-55
lines changed

5 files changed

+218
-55
lines changed

docs/notebooks/pet_motion_estimation.ipynb

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,9 +406,9 @@
406406
}
407407
],
408408
"source": [
409-
"from nifreeze.model import PETModel\n",
409+
"from nifreeze.model import BSplinePETModel\n",
410410
"\n",
411-
"model = PETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
411+
"model = BSplinePETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
412412
]
413413
},
414414
{
@@ -2258,7 +2258,7 @@
22582258
"from nifreeze.estimator import PETMotionEstimator\n",
22592259
"\n",
22602260
"# Instantiate with a PETModel or appropriate model instance\n",
2261-
"model = PETModel(\n",
2261+
"model = BSplinePETModel(\n",
22622262
" dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=pet_dataset.total_duration\n",
22632263
")\n",
22642264
"estimator = PETMotionEstimator(model=model)\n",

src/nifreeze/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from nifreeze.data.base import BaseDataset
4141
from nifreeze.data.pet import PET
4242
from nifreeze.model.base import BaseModel, ModelFactory
43-
from nifreeze.model.pet import PETModel
43+
from nifreeze.model.pet import BSplinePETModel
4444
from nifreeze.registration.ants import (
4545
Registration,
4646
_prepare_registration_data,
@@ -255,8 +255,8 @@ def run(self, pet_dataset, omp_nthreads=None):
255255
total_duration=pet_dataset.total_duration,
256256
)
257257

258-
# Instantiate PETModel explicitly
259-
model = PETModel(
258+
# Instantiate the PET model explicitly
259+
model = BSplinePETModel(
260260
dataset=train_dataset,
261261
timepoints=train_times,
262262
xlim=pet_dataset.total_duration,

src/nifreeze/model/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
DTIModel,
3434
GPModel,
3535
)
36-
from nifreeze.model.pet import PETModel
36+
from nifreeze.model.pet import BSplinePETModel
3737

3838
__all__ = (
3939
"ModelFactory",
@@ -43,5 +43,5 @@
4343
"DTIModel",
4444
"GPModel",
4545
"TrivialModel",
46-
"PETModel",
46+
"BSplinePETModel",
4747
)

src/nifreeze/model/pet.py

Lines changed: 205 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222
#
2323
"""Models for nuclear imaging."""
2424

25+
from importlib import import_module
2526
from os import cpu_count
27+
from typing import Any
2628

2729
import nibabel as nb
2830
import numpy as np
@@ -31,60 +33,86 @@
3133
from scipy.interpolate import BSpline
3234
from scipy.sparse.linalg import cg
3335

36+
from nifreeze.data.pet import PET
3437
from nifreeze.model.base import BaseModel
3538

3639
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3740
"""Time frame tolerance in seconds."""
3841

3942

40-
class PETModel(BaseModel):
41-
"""A PET imaging realignment model based on B-Spline approximation."""
43+
def _exec_fit(model, data, chunk=None, **kwargs):
44+
return model.fit(data, **kwargs), chunk
4245

43-
__slots__ = (
44-
"_t",
45-
"_x",
46-
"_xlim",
47-
"_order",
48-
"_n_ctrl",
49-
"_datashape",
50-
"_mask",
51-
"_smooth_fwhm",
52-
"_thresh_pct",
53-
)
46+
47+
def _exec_predict(model, chunk=None, **kwargs):
48+
"""Propagate model parameters and call predict."""
49+
return np.squeeze(model.predict(**kwargs)), chunk
50+
51+
52+
class BasePETModel(BaseModel):
53+
"""Interface and default methods for PET models."""
54+
55+
__slots__ = {
56+
"_data_mask": "A mask for the voxels that will be fitted and predicted",
57+
"_x" : "",
58+
"_xlim" : "",
59+
"_smooth_fwhm" : "FWHM in mm over which to smooth",
60+
"_thresh_pct" : "Thresholding percentile for the signal",
61+
"_model_class": "Defining a model class",
62+
"_modelargs": "Arguments acceptable by the underlying model",
63+
"_models": "List with one or more (if parallel execution) model instances",
64+
}
5465

5566
def __init__(
5667
self,
57-
dataset,
58-
timepoints=None,
59-
xlim=None,
60-
n_ctrl=None,
61-
order=3,
62-
smooth_fwhm=10,
63-
thresh_pct=20,
68+
dataset: PET,
69+
timepoints: list | np.ndarray = None, ## Is there a way to use array-like
70+
xlim: list | np.ndarray = None,
71+
smooth_fwhm: float = 10.0,
72+
thresh_pct : float = 20.0,
6473
**kwargs,
6574
):
66-
"""
67-
Create the B-Spline interpolating matrix.
75+
"""Initialization.
6876
69-
Parameters:
70-
-----------
71-
timepoints : :obj:`list`
77+
Parameters
78+
----------
79+
timepoints : :obj:`list` or :obj:`~np.ndarray`
7280
The timing (in sec) of each PET volume.
7381
E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
7482
420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
75-
76-
n_ctrl : :obj:`int`
77-
Number of B-Spline control points. If `None`, then one control point every
78-
six timepoints will be used. The less control points, the smoother is the
79-
model.
80-
83+
xlim : .
84+
.
85+
smooth_fwhm : obj:`float`
86+
FWHM in mm over which to smooth the signal.
87+
thresh_pct : obj:`float`
88+
Thresholding percentile for the signal.
8189
"""
90+
8291
super().__init__(dataset, **kwargs)
8392

93+
# Duck typing, instead of explicitly testing for PET type
94+
if not hasattr(dataset, "total_duration"):
95+
raise TypeError("Dataset MUST be a PET object.")
96+
97+
if not hasattr(dataset, "midframe"):
98+
raise ValueError("Dataset MUST have a midframe.")
99+
100+
# ToDO
101+
# Are the timepoints your "gradients" ??? If so, can they be computed
102+
# from frame_time or frame_duration
103+
# Or else frame_time and frame_duration ????
104+
105+
self._data_mask = (
106+
dataset.brainmask
107+
if dataset.brainmask is not None
108+
else np.ones(dataset.dataobj.shape[:3], dtype=bool)
109+
)
110+
111+
# ToDo
112+
# Are timepoints and xlim features that all PET models require ??
84113
if timepoints is None or xlim is None:
85-
raise TypeError("timepoints must be provided in initialization")
114+
raise ValueError("`timepoints` and `xlim` must be specified and have a nonzero value.")
86115

87-
self._order = order
88116
self._x = np.array(timepoints, dtype="float32")
89117
self._xlim = xlim
90118
self._smooth_fwhm = smooth_fwhm
@@ -95,14 +123,7 @@ def __init__(
95123
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
96124
raise ValueError("Last frame midpoint should not be equal or greater than duration")
97125

98-
# Calculate index coordinates in the B-Spline grid
99-
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
100-
101-
# B-Spline knots
102-
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
103-
104-
self._datashape = None
105-
self._mask = None
126+
super().__init__(dataset, **kwargs)
106127

107128
@property
108129
def is_fitted(self):
@@ -111,17 +132,24 @@ def is_fitted(self):
111132
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
112133
"""Fit the model."""
113134

135+
n_jobs = n_jobs or 1
136+
114137
if self._locked_fit is not None:
115138
return n_jobs
116139

117140
if index is not None:
118141
raise NotImplementedError("Fitting with held-out data is not supported")
142+
143+
# ToDo
144+
# Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
119145
timepoints = kwargs.get("timepoints", None) or self._x
120-
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
121146

147+
# ToDo
148+
# data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
122149
data = self._dataset.dataobj
123150
brainmask = self._dataset.brainmask
124151

152+
# Preprocess the data
125153
if self._smooth_fwhm > 0:
126154
smoothed_img = smooth_image(
127155
nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
@@ -135,6 +163,135 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
135163
# Convert data into V (voxels) x T (timepoints)
136164
data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
137165

166+
# ToDo
167+
# What is the gtab equivalent of PET ?
168+
model_str = getattr(self, "_model_class", "")
169+
module_name, class_name = model_str.rsplit(".", 1)
170+
model = getattr(
171+
import_module(module_name),
172+
class_name,
173+
)(gtab, **kwargs)
174+
175+
fit_kwargs: dict[str, Any] = {} # Add here keyword arguments
176+
177+
# Split data into chunks of group of slices
178+
data_chunks = np.array_split(data, n_jobs)
179+
180+
self._models = [None] * n_jobs
181+
182+
# Parallelize process with joblib
183+
with Parallel(n_jobs=n_jobs) as executor:
184+
results = executor(
185+
delayed(_exec_fit)(model, dchunk, i, **fit_kwargs)
186+
for i, dchunk in enumerate(data_chunks)
187+
)
188+
for submodel, rindex in results:
189+
self._models[rindex] = submodel
190+
191+
return n_jobs
192+
193+
194+
def fit_predict(self, index: int | None = None, **kwargs):
195+
"""Return the corrected volume using B-spline interpolation."""
196+
197+
n_models = self._fit(
198+
index,
199+
n_jobs=kwargs.pop("n_jobs"),
200+
**kwargs,
201+
)
202+
203+
if index is None: # If no index, just fit the data.
204+
return None
205+
206+
# ToDo
207+
# What are the gtab (and S0 if any) equivalent of PET ?
208+
if n_models == 1:
209+
predicted, _ = _exec_predict(
210+
self._models[0], **(kwargs | {"gtab": gradient, "S0": self._S0})
211+
)
212+
else:
213+
predicted = [None] * n_models
214+
S0 = np.array_split(self._S0, n_models)
215+
216+
# Parallelize process with joblib
217+
with Parallel(n_jobs=n_models) as executor:
218+
results = executor(
219+
delayed(_exec_predict)(
220+
model,
221+
chunk=i,
222+
**(kwargs | {"gtab": gradient, "S0": S0[i]}),
223+
)
224+
for i, model in enumerate(self._models)
225+
)
226+
for subprediction, index in results:
227+
predicted[index] = subprediction
228+
229+
predicted = np.hstack(predicted)
230+
231+
retval = np.zeros_like(self._data_mask, dtype=self._dataset.dataobj.dtype)
232+
retval[self._data_mask, ...] = predicted
233+
return retval
234+
235+
236+
class BSplinePETModel(BasePETModel):
237+
"""A PET imaging realignment model based on B-Spline approximation."""
238+
239+
__slots__ = (
240+
"_t",
241+
"_order",
242+
"_n_ctrl",
243+
)
244+
245+
def __init__(
246+
self,
247+
dataset: PET,
248+
n_ctrl: int = None,
249+
order: int = 3,
250+
**kwargs,
251+
):
252+
"""Create the B-Spline interpolating matrix.
253+
254+
Parameters
255+
----------
256+
n_ctrl : :obj:`int`
257+
Number of B-Spline control points. If `None`, then one control point every
258+
six timepoints will be used. The less control points, the smoother is the
259+
model.
260+
order : :obj:`int`
261+
Order of the B-Spline approximation.
262+
"""
263+
264+
super().__init__(dataset, **kwargs)
265+
266+
self._order = order
267+
268+
# Calculate index coordinates in the B-Spline grid
269+
self._n_ctrl = n_ctrl or (len(self._x) // 4) + 1
270+
271+
# B-Spline knots
272+
self._t = np.arange(-3, self._n_ctrl + 4, dtype="float32")
273+
274+
275+
@property
276+
def is_fitted(self):
277+
return self._locked_fit is not None
278+
279+
def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
280+
"""Fit the model."""
281+
282+
if self._locked_fit is not None:
283+
return n_jobs
284+
285+
if index is not None:
286+
raise NotImplementedError("Fitting with held-out data is not supported")
287+
288+
# ToDo
289+
# Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
290+
timepoints = kwargs.get("timepoints", None) or self._x
291+
x = (np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl
292+
293+
data = self._dataset.dataobj
294+
138295
# A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
139296
A = BSpline.design_matrix(x, self._t, k=self._order)
140297
AT = A.T
@@ -149,6 +306,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
149306
def fit_predict(self, index: int | None = None, **kwargs):
150307
"""Return the corrected volume using B-spline interpolation."""
151308

309+
# ToDo
310+
# Does the below apply to PET ? Martin has the return None statement
311+
# if index is None:
312+
# raise RuntimeError(
313+
# f"Model {self.__class__.__name__} does not allow locking.")
314+
152315
# Fit the BSpline basis on all data
153316
if self._locked_fit is None:
154317
self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)

0 commit comments

Comments
 (0)