2222#
2323"""Models for nuclear imaging."""
2424
25+ import abc
26+ from abc import ABC
2527from os import cpu_count
28+ from typing import Union
2629
2730import nibabel as nb
2831import numpy as np
3134from scipy .interpolate import BSpline
3235from scipy .sparse .linalg import cg
3336
37+ from nifreeze .data .pet import PET
3438from nifreeze .model .base import BaseModel
3539
3640DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3741"""Time frame tolerance in seconds."""
3842
3943
40- class PETModel ( BaseModel ):
41- """A PET imaging realignment model based on B-Spline approximation."""
44+ def _exec_fit ( model , data , chunk = None , ** kwargs ):
45+ return model . fit ( data , ** kwargs ), chunk
4246
43- __slots__ = (
44- "_t" ,
45- "_x" ,
46- "_xlim" ,
47- "_order" ,
48- "_n_ctrl" ,
49- "_datashape" ,
50- "_mask" ,
51- "_smooth_fwhm" ,
52- "_thresh_pct" ,
53- )
47+
48+ def _exec_predict (model , chunk = None , ** kwargs ):
49+ """Propagate model parameters and call predict."""
50+ return np .squeeze (model .predict (** kwargs )), chunk
51+
52+
53+ class BasePETModel (BaseModel , ABC ):
54+ """Interface and default methods for PET models."""
55+
56+ __metaclass__ = abc .ABCMeta
57+
58+ __slots__ = {
59+ "_data_mask" : "A mask for the voxels that will be fitted and predicted" ,
60+ "_x" : "" ,
61+ "_xlim" : "" ,
62+ "_smooth_fwhm" : "FWHM in mm over which to smooth" ,
63+ "_thresh_pct" : "Thresholding percentile for the signal" ,
64+ "_model_class" : "Defining a model class" ,
65+ "_modelargs" : "Arguments acceptable by the underlying model" ,
66+ "_models" : "List with one or more (if parallel execution) model instances" ,
67+ }
5468
5569 def __init__ (
5670 self ,
57- dataset ,
58- timepoints = None ,
59- xlim = None ,
60- n_ctrl = None ,
61- order = 3 ,
62- smooth_fwhm = 10 ,
63- thresh_pct = 20 ,
71+ dataset : PET ,
72+ timepoints : list | np .ndarray = None , ## Is there a way to use array-like
73+ xlim : list | np .ndarray = None ,
74+ smooth_fwhm : float = 10.0 ,
75+ thresh_pct : float = 20.0 ,
6476 ** kwargs ,
6577 ):
66- """
67- Create the B-Spline interpolating matrix.
78+ """Initialization.
6879
69- Parameters:
70- -----------
71- timepoints : :obj:`list`
80+ Parameters
81+ ----------
82+ timepoints : :obj:`list` or :obj:`~np.ndarray`
7283 The timing (in sec) of each PET volume.
7384 E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
7485 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
75-
76- n_ctrl : :obj:`int`
77- Number of B-Spline control points. If `None`, then one control point every
78- six timepoints will be used. The less control points, the smoother is the
79- model.
80-
86+ xlim : .
87+ .
88+ smooth_fwhm : obj:`float`
89+ FWHM in mm over which to smooth the signal.
90+ thresh_pct : obj:`float`
91+ Thresholding percentile for the signal.
8192 """
93+
8294 super ().__init__ (dataset , ** kwargs )
8395
96+ # Duck typing, instead of explicitly testing for PET type
97+ if not hasattr (dataset , "total_duration" ):
98+ raise TypeError ("Dataset MUST be a PET object." )
99+
100+ if not hasattr (dataset , "midframe" ):
101+ raise ValueError ("Dataset MUST have a midframe." )
102+
103+ # ToDO
104+ # Are the timepoints your "gradients" ??? If so, can they be computed
105+ # from frame_time or frame_duration
106+ # Or else frame_time and frame_duration ????
107+
108+ self ._data_mask = (
109+ dataset .brainmask
110+ if dataset .brainmask is not None
111+ else np .ones (dataset .dataobj .shape [:3 ], dtype = bool )
112+ )
113+
114+ # ToDo
115+ # Are timepoints and xlim features that all PET models require ??
84116 if timepoints is None or xlim is None :
85- raise TypeError ( " timepoints must be provided in initialization " )
117+ raise ValueError ( "` timepoints` and `xlim` must be specified and have a nonzero value. " )
86118
87- self ._order = order
88119 self ._x = np .array (timepoints , dtype = "float32" )
89120 self ._xlim = xlim
90121 self ._smooth_fwhm = smooth_fwhm
@@ -95,33 +126,15 @@ def __init__(
95126 if self ._x [- 1 ] > (self ._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL ):
96127 raise ValueError ("Last frame midpoint should not be equal or greater than duration" )
97128
98- # Calculate index coordinates in the B-Spline grid
99- self ._n_ctrl = n_ctrl or (len (timepoints ) // 4 ) + 1
100-
101- # B-Spline knots
102- self ._t = np .arange (- 3 , float (self ._n_ctrl ) + 4 , dtype = "float32" )
103-
104- self ._datashape = None
105- self ._mask = None
106-
107- @property
108- def is_fitted (self ):
109- return self ._locked_fit is not None
110-
111- def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ):
112- """Fit the model."""
113-
114- if self ._locked_fit is not None :
115- return n_jobs
116-
117- if index is not None :
118- raise NotImplementedError ("Fitting with held-out data is not supported" )
119- timepoints = kwargs .get ("timepoints" , None ) or self ._x
120- x = (np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl
129+ super ().__init__ (dataset , ** kwargs )
121130
131+ def _preproces_data (self ) -> np .ndarray :
132+ # ToDo
133+ # data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
122134 data = self ._dataset .dataobj
123135 brainmask = self ._dataset .brainmask
124136
137+ # Preprocess the data
125138 if self ._smooth_fwhm > 0 :
126139 smoothed_img = smooth_image (
127140 nb .Nifti1Image (data , self ._dataset .affine ), self ._smooth_fwhm
@@ -133,7 +146,73 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
133146 data [data < thresh_val ] = 0
134147
135148 # Convert data into V (voxels) x T (timepoints)
136- data = data .reshape ((- 1 , data .shape [- 1 ])) if brainmask is None else data [brainmask ]
149+ return data .reshape ((- 1 , data .shape [- 1 ])) if brainmask is None else data [brainmask ]
150+
151+ @property
152+ def is_fitted (self ) -> bool :
153+ return self ._locked_fit is not None
154+
155+ @abc .abstractmethod
156+ def fit_predict (self , index : int | None = None , ** kwargs ) -> Union [np .ndarray , None ]:
157+ """Predict the corrected volume."""
158+ return
159+
160+
161+ class BSplinePETModel (BasePETModel ):
162+ """A PET imaging realignment model based on B-Spline approximation."""
163+
164+ __slots__ = (
165+ "_t" ,
166+ "_order" ,
167+ "_n_ctrl" ,
168+ )
169+
170+ def __init__ (
171+ self ,
172+ dataset : PET ,
173+ n_ctrl : int = None ,
174+ order : int = 3 ,
175+ ** kwargs ,
176+ ):
177+ """Create the B-Spline interpolating matrix.
178+
179+ Parameters
180+ ----------
181+ n_ctrl : :obj:`int`
182+ Number of B-Spline control points. If `None`, then one control point every
183+ six timepoints will be used. The less control points, the smoother is the
184+ model.
185+ order : :obj:`int`
186+ Order of the B-Spline approximation.
187+ """
188+
189+ super ().__init__ (dataset , ** kwargs )
190+
191+ self ._order = order
192+
193+ # Calculate index coordinates in the B-Spline grid
194+ self ._n_ctrl = n_ctrl or (len (self ._x ) // 4 ) + 1
195+
196+ # B-Spline knots
197+ self ._t = np .arange (- 3 , self ._n_ctrl + 4 , dtype = "float32" )
198+
199+ def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ) -> Union [int , None ]:
200+ """Fit the model."""
201+
202+ n_jobs = n_jobs or 1
203+
204+ if self ._locked_fit is not None :
205+ return n_jobs
206+
207+ if index is not None :
208+ raise NotImplementedError ("Fitting with held-out data is not supported" )
209+
210+ data = self ._preproces_data ()
211+
212+ # ToDo
213+ # Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
214+ timepoints = kwargs .get ("timepoints" , None ) or self ._x
215+ x = (np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl
137216
138217 # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
139218 A = BSpline .design_matrix (x , self ._t , k = self ._order )
@@ -146,9 +225,15 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs):
146225
147226 self ._locked_fit = np .array ([r [0 ] for r in results ])
148227
149- def fit_predict (self , index : int | None = None , ** kwargs ):
228+ def fit_predict (self , index : int | None = None , ** kwargs ) -> Union [ np . ndarray , None ] :
150229 """Return the corrected volume using B-spline interpolation."""
151230
231+ # ToDo
232+ # Does the below apply to PET ? Martin has the return None statement
233+ # if index is None:
234+ # raise RuntimeError(
235+ # f"Model {self.__class__.__name__} does not allow locking.")
236+
152237 # Fit the BSpline basis on all data
153238 if self ._locked_fit is None :
154239 self ._fit (index , n_jobs = kwargs .pop ("n_jobs" , None ), ** kwargs )
0 commit comments