2222#
2323"""Models for nuclear imaging."""
2424
25+ from importlib import import_module
2526from os import cpu_count
27+ from typing import Any
2628
2729import nibabel as nb
2830import numpy as np
3133from scipy .interpolate import BSpline
3234from scipy .sparse .linalg import cg
3335
36+ from nifreeze .data .pet import PET
3437from nifreeze .model .base import BaseModel
3538
3639DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3740"""Time frame tolerance in seconds."""
3841
3942
40- class PETModel ( BaseModel ):
41- """A PET imaging realignment model based on B-Spline approximation."""
43+ def _exec_fit ( model , data , chunk = None , ** kwargs ):
44+ return model . fit ( data , ** kwargs ), chunk
4245
43- __slots__ = (
44- "_t" ,
45- "_x" ,
46- "_xlim" ,
47- "_order" ,
48- "_n_ctrl" ,
49- "_datashape" ,
50- "_mask" ,
51- "_smooth_fwhm" ,
52- "_thresh_pct" ,
53- )
46+
47+ def _exec_predict (model , chunk = None , ** kwargs ):
48+ """Propagate model parameters and call predict."""
49+ return np .squeeze (model .predict (** kwargs )), chunk
50+
51+
52+ class BasePETModel (BaseModel ):
53+ """Interface and default methods for PET models."""
54+
55+ __slots__ = {
56+ "_data_mask" : "A mask for the voxels that will be fitted and predicted" ,
57+ "_x" : "" ,
58+ "_xlim" : "" ,
59+ "_smooth_fwhm" : "FWHM in mm over which to smooth" ,
60+ "_thresh_pct" : "Thresholding percentile for the signal" ,
61+ "_model_class" : "Defining a model class" ,
62+ "_modelargs" : "Arguments acceptable by the underlying model" ,
63+ "_models" : "List with one or more (if parallel execution) model instances" ,
64+ }
5465
5566 def __init__ (
5667 self ,
57- dataset ,
58- timepoints = None ,
59- xlim = None ,
60- n_ctrl = None ,
61- order = 3 ,
62- smooth_fwhm = 10 ,
63- thresh_pct = 20 ,
68+ dataset : PET ,
69+ timepoints : list | np .ndarray = None , ## Is there a way to use array-like
70+ xlim : list | np .ndarray = None ,
71+ smooth_fwhm : float = 10.0 ,
72+ thresh_pct : float = 20.0 ,
6473 ** kwargs ,
6574 ):
66- """
67- Create the B-Spline interpolating matrix.
75+ """Initialization.
6876
69- Parameters:
70- -----------
71- timepoints : :obj:`list`
77+ Parameters
78+ ----------
79+ timepoints : :obj:`list` or :obj:`~np.ndarray`
7280 The timing (in sec) of each PET volume.
7381 E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
7482 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
75-
76- n_ctrl : :obj:`int`
77- Number of B-Spline control points. If `None`, then one control point every
78- six timepoints will be used. The less control points, the smoother is the
79- model.
80-
83+ xlim : .
84+ .
85+ smooth_fwhm : obj:`float`
86+ FWHM in mm over which to smooth the signal.
87+ thresh_pct : obj:`float`
88+ Thresholding percentile for the signal.
8189 """
90+
8291 super ().__init__ (dataset , ** kwargs )
8392
93+ # Duck typing, instead of explicitly testing for PET type
94+ if not hasattr (dataset , "total_duration" ):
95+ raise TypeError ("Dataset MUST be a PET object." )
96+
97+ if not hasattr (dataset , "midframe" ):
98+ raise ValueError ("Dataset MUST have a midframe." )
99+
100+ # ToDO
101+ # Are the timepoints your "gradients" ??? If so, can they be computed
102+ # from frame_time or frame_duration
103+ # Or else frame_time and frame_duration ????
104+
105+ self ._data_mask = (
106+ dataset .brainmask
107+ if dataset .brainmask is not None
108+ else np .ones (dataset .dataobj .shape [:3 ], dtype = bool )
109+ )
110+
111+ # ToDo
112+ # Are timepoints and xlim features that all PET models require ??
84113 if timepoints is None or xlim is None :
85- raise TypeError ( " timepoints must be provided in initialization " )
114+ raise ValueError ( "` timepoints` and `xlim` must be specified and have a nonzero value. " )
86115
87- self ._order = order
88116 self ._x = np .array (timepoints , dtype = "float32" )
89117 self ._xlim = xlim
90118 self ._smooth_fwhm = smooth_fwhm
@@ -95,14 +123,7 @@ def __init__(
95123 if self ._x [- 1 ] > (self ._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL ):
96124 raise ValueError ("Last frame midpoint should not be equal or greater than duration" )
97125
98- # Calculate index coordinates in the B-Spline grid
99- self ._n_ctrl = n_ctrl or (len (timepoints ) // 4 ) + 1
100-
101- # B-Spline knots
102- self ._t = np .arange (- 3 , float (self ._n_ctrl ) + 4 , dtype = "float32" )
103-
104- self ._datashape = None
105- self ._mask = None
126+ super ().__init__ (dataset , ** kwargs )
106127
107128 @property
108129 def is_fitted (self ):
@@ -111,17 +132,24 @@ def is_fitted(self):
111132 def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
112133 """Fit the model."""
113134
135+ n_jobs = n_jobs or 1
136+
114137 if self ._locked_fit is not None :
115138 return n_jobs
116139
117140 if index is not None :
118141 raise NotImplementedError ("Fitting with held-out data is not supported" )
142+
143+ # ToDo
144+ # Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
119145 timepoints = kwargs .get ("timepoints" , None ) or self ._x
120- x = (np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl
121146
147+ # ToDo
148+ # data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
122149 data = self ._dataset .dataobj
123150 brainmask = self ._dataset .brainmask
124151
152+ # Preprocess the data
125153 if self ._smooth_fwhm > 0 :
126154 smoothed_img = smooth_image (
127155 nb .Nifti1Image (data , self ._dataset .affine ), self ._smooth_fwhm
@@ -135,6 +163,135 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
135163 # Convert data into V (voxels) x T (timepoints)
136164 data = data .reshape ((- 1 , data .shape [- 1 ])) if brainmask is None else data [brainmask ]
137165
166+ # ToDo
167+ # What is the gtab equivalent of PET ?
168+ model_str = getattr (self , "_model_class" , "" )
169+ module_name , class_name = model_str .rsplit ("." , 1 )
170+ model = getattr (
171+ import_module (module_name ),
172+ class_name ,
173+ )(gtab , ** kwargs )
174+
175+ fit_kwargs : dict [str , Any ] = {} # Add here keyword arguments
176+
177+ # Split data into chunks of group of slices
178+ data_chunks = np .array_split (data , n_jobs )
179+
180+ self ._models = [None ] * n_jobs
181+
182+ # Parallelize process with joblib
183+ with Parallel (n_jobs = n_jobs ) as executor :
184+ results = executor (
185+ delayed (_exec_fit )(model , dchunk , i , ** fit_kwargs )
186+ for i , dchunk in enumerate (data_chunks )
187+ )
188+ for submodel , rindex in results :
189+ self ._models [rindex ] = submodel
190+
191+ return n_jobs
192+
193+
194+ def fit_predict (self , index : int | None = None , ** kwargs ):
195+ """Return the corrected volume using B-spline interpolation."""
196+
197+ n_models = self ._fit (
198+ index ,
199+ n_jobs = kwargs .pop ("n_jobs" ),
200+ ** kwargs ,
201+ )
202+
203+ if index is None : # If no index, just fit the data.
204+ return None
205+
206+ # ToDo
207+ # What are the gtab (and S0 if any) equivalent of PET ?
208+ if n_models == 1 :
209+ predicted , _ = _exec_predict (
210+ self ._models [0 ], ** (kwargs | {"gtab" : gradient , "S0" : self ._S0 })
211+ )
212+ else :
213+ predicted = [None ] * n_models
214+ S0 = np .array_split (self ._S0 , n_models )
215+
216+ # Parallelize process with joblib
217+ with Parallel (n_jobs = n_models ) as executor :
218+ results = executor (
219+ delayed (_exec_predict )(
220+ model ,
221+ chunk = i ,
222+ ** (kwargs | {"gtab" : gradient , "S0" : S0 [i ]}),
223+ )
224+ for i , model in enumerate (self ._models )
225+ )
226+ for subprediction , index in results :
227+ predicted [index ] = subprediction
228+
229+ predicted = np .hstack (predicted )
230+
231+ retval = np .zeros_like (self ._data_mask , dtype = self ._dataset .dataobj .dtype )
232+ retval [self ._data_mask , ...] = predicted
233+ return retval
234+
235+
236+ class BSplinePETModel (BasePETModel ):
237+ """A PET imaging realignment model based on B-Spline approximation."""
238+
239+ __slots__ = (
240+ "_t" ,
241+ "_order" ,
242+ "_n_ctrl" ,
243+ )
244+
245+ def __init__ (
246+ self ,
247+ dataset : PET ,
248+ n_ctrl : int = None ,
249+ order : int = 3 ,
250+ ** kwargs ,
251+ ):
252+ """Create the B-Spline interpolating matrix.
253+
254+ Parameters
255+ ----------
256+ n_ctrl : :obj:`int`
257+ Number of B-Spline control points. If `None`, then one control point every
258+ six timepoints will be used. The less control points, the smoother is the
259+ model.
260+ order : :obj:`int`
261+ Order of the B-Spline approximation.
262+ """
263+
264+ super ().__init__ (dataset , ** kwargs )
265+
266+ self ._order = order
267+
268+ # Calculate index coordinates in the B-Spline grid
269+ self ._n_ctrl = n_ctrl or (len (self ._x ) // 4 ) + 1
270+
271+ # B-Spline knots
272+ self ._t = np .arange (- 3 , self ._n_ctrl + 4 , dtype = "float32" )
273+
274+
275+ @property
276+ def is_fitted (self ):
277+ return self ._locked_fit is not None
278+
279+ def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
280+ """Fit the model."""
281+
282+ if self ._locked_fit is not None :
283+ return n_jobs
284+
285+ if index is not None :
286+ raise NotImplementedError ("Fitting with held-out data is not supported" )
287+
288+ # ToDo
289+ # Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
290+ timepoints = kwargs .get ("timepoints" , None ) or self ._x
291+ x = (np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl
292+
293+ data = self ._dataset .dataobj
294+
138295 # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
139296 A = BSpline .design_matrix (x , self ._t , k = self ._order )
140297 AT = A .T
@@ -149,6 +306,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
149306 def fit_predict (self , index : int | None = None , ** kwargs ):
150307 """Return the corrected volume using B-spline interpolation."""
151308
309+ # ToDo
310+ # Does the below apply to PET ? Martin has the return None statement
311+ # if index is None:
312+ # raise RuntimeError(
313+ # f"Model {self.__class__.__name__} does not allow locking.")
314+
152315 # Fit the BSpline basis on all data
153316 if self ._locked_fit is None :
154317 self ._fit (index , n_jobs = kwargs .pop ("n_jobs" , None ), ** kwargs )
0 commit comments