2222#
2323"""Models for nuclear imaging."""
2424
25+ from abc import ABC , ABCMeta , abstractmethod
2526from os import cpu_count
27+ from typing import Union
2628
2729import nibabel as nb
2830import numpy as np
3133from scipy .interpolate import BSpline
3234from scipy .sparse .linalg import cg
3335
36+ from nifreeze .data .pet import PET
3437from nifreeze .model .base import BaseModel
3538
3639DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3740"""Time frame tolerance in seconds."""
3841
3942
40- class PETModel ( BaseModel ):
41- """A PET imaging realignment model based on B-Spline approximation."""
43+ def _exec_fit ( model , data , chunk = None , ** kwargs ):
44+ return model . fit ( data , ** kwargs ), chunk
4245
43- __slots__ = (
44- "_t" ,
45- "_x" ,
46- "_xlim" ,
47- "_order" ,
48- "_n_ctrl" ,
49- "_datashape" ,
50- "_mask" ,
51- "_smooth_fwhm" ,
52- "_thresh_pct" ,
53- )
46+
47+ def _exec_predict (model , chunk = None , ** kwargs ):
48+ """Propagate model parameters and call predict."""
49+ return np .squeeze (model .predict (** kwargs )), chunk
50+
51+
52+ class BasePETModel (BaseModel , ABC ):
53+ """Interface and default methods for PET models."""
54+
55+ __metaclass__ = ABCMeta
56+
57+ __slots__ = {
58+ "_data_mask" : "A mask for the voxels that will be fitted and predicted" ,
59+ "_x" : "" ,
60+ "_xlim" : "" ,
61+ "_smooth_fwhm" : "FWHM in mm over which to smooth" ,
62+ "_thresh_pct" : "Thresholding percentile for the signal" ,
63+ "_model_class" : "Defining a model class" ,
64+ "_modelargs" : "Arguments acceptable by the underlying model" ,
65+ "_models" : "List with one or more (if parallel execution) model instances" ,
66+ }
5467
5568 def __init__ (
5669 self ,
57- dataset ,
58- timepoints = None ,
59- xlim = None ,
60- n_ctrl = None ,
61- order = 3 ,
62- smooth_fwhm = 10 ,
63- thresh_pct = 20 ,
70+ dataset : PET ,
71+ timepoints : list | np .ndarray | None = None , ## Is there a way to use array-like
72+ xlim : list | np .ndarray | None = None ,
73+ smooth_fwhm : float = 10.0 ,
74+ thresh_pct : float = 20.0 ,
6475 ** kwargs ,
6576 ):
66- """
67- Create the B-Spline interpolating matrix.
77+ """Initialization.
6878
69- Parameters:
70- -----------
71- timepoints : :obj:`list`
79+ Parameters
80+ ----------
81+ timepoints : :obj:`list` or :obj:`~np.ndarray`
7282 The timing (in sec) of each PET volume.
7383 E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
7484 420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``
75-
76- n_ctrl : :obj:`int`
77- Number of B-Spline control points. If `None`, then one control point every
78- six timepoints will be used. The less control points, the smoother is the
79- model.
80-
85+ xlim : .
86+ .
87+ smooth_fwhm : obj:`float`
88+ FWHM in mm over which to smooth the signal.
89+ thresh_pct : obj:`float`
90+ Thresholding percentile for the signal.
8191 """
92+
8293 super ().__init__ (dataset , ** kwargs )
8394
95+ # Duck typing, instead of explicitly testing for PET type
96+ if not hasattr (dataset , "total_duration" ):
97+ raise TypeError ("Dataset MUST be a PET object." )
98+
99+ if not hasattr (dataset , "midframe" ):
100+ raise ValueError ("Dataset MUST have a midframe." )
101+
102+ # ToDO
103+ # Are the timepoints your "gradients" ??? If so, can they be computed
104+ # from frame_time or frame_duration
105+ # Or else frame_time and frame_duration ????
106+
107+ self ._data_mask = (
108+ dataset .brainmask
109+ if dataset .brainmask is not None
110+ else np .ones (dataset .dataobj .shape [:3 ], dtype = bool )
111+ )
112+
113+ # ToDo
114+ # Are timepoints and xlim features that all PET models require ??
84115 if timepoints is None or xlim is None :
85- raise TypeError ( " timepoints must be provided in initialization " )
116+ raise ValueError ( "` timepoints` and `xlim` must be specified and have a nonzero value. " )
86117
87- self ._order = order
88- self ._x = np .array (timepoints , dtype = "float32" )
89- self ._xlim = xlim
118+ self ._x = np .asarray (timepoints , dtype = "float32" )
119+ self ._xlim = np .asarray (xlim )
90120 self ._smooth_fwhm = smooth_fwhm
91121 self ._thresh_pct = thresh_pct
92122
@@ -95,62 +125,114 @@ def __init__(
95125 if self ._x [- 1 ] > (self ._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL ):
96126 raise ValueError ("Last frame midpoint should not be equal or greater than duration" )
97127
98- # Calculate index coordinates in the B-Spline grid
99- self ._n_ctrl = n_ctrl or (len (timepoints ) // 4 ) + 1
128+ def _preproces_data (self ) -> np .ndarray :
129+ # ToDo
130+ # data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
131+ data = self ._dataset .dataobj
132+ brainmask = self ._dataset .brainmask
100133
101- # B-Spline knots
102- self ._t = np .arange (- 3 , float (self ._n_ctrl ) + 4 , dtype = "float32" )
134+ # Preprocess the data
135+ if self ._smooth_fwhm > 0 :
136+ smoothed_img = smooth_image (
137+ nb .Nifti1Image (data , self ._dataset .affine ), self ._smooth_fwhm
138+ )
139+ data = smoothed_img .get_fdata ()
140+
141+ if self ._thresh_pct > 0 :
142+ thresh_val = np .percentile (data , self ._thresh_pct )
143+ data [data < thresh_val ] = 0
103144
104- self . _datashape = None
105- self . _mask = None
145+ # Convert data into V (voxels) x T (timepoints)
146+ return data . reshape (( - 1 , data . shape [ - 1 ])) if brainmask is None else data [ brainmask ]
106147
107148 @property
108- def is_fitted (self ):
149+ def is_fitted (self ) -> bool :
109150 return self ._locked_fit is not None
110151
152+ @abstractmethod
153+ def fit_predict (self , index : int | None = None , ** kwargs ) -> Union [np .ndarray , None ]:
154+ """Predict the corrected volume."""
155+ return None
156+
157+
158+ class BSplinePETModel (BasePETModel ):
159+ """A PET imaging realignment model based on B-Spline approximation."""
160+
161+ __slots__ = (
162+ "_t" ,
163+ "_order" ,
164+ "_n_ctrl" ,
165+ )
166+
167+ def __init__ (
168+ self ,
169+ dataset : PET ,
170+ n_ctrl : int | None = None ,
171+ order : int = 3 ,
172+ ** kwargs ,
173+ ):
174+ """Create the B-Spline interpolating matrix.
175+
176+ Parameters
177+ ----------
178+ n_ctrl : :obj:`int`
179+ Number of B-Spline control points. If `None`, then one control point every
180+ six timepoints will be used. The less control points, the smoother is the
181+ model.
182+ order : :obj:`int`
183+ Order of the B-Spline approximation.
184+ """
185+
186+ super ().__init__ (dataset , ** kwargs )
187+
188+ self ._order = order
189+
190+ # Calculate index coordinates in the B-Spline grid
191+ self ._n_ctrl = n_ctrl or (len (self ._x ) // 4 ) + 1
192+
193+ # B-Spline knots
194+ self ._t = np .arange (- 3 , self ._n_ctrl + 4 , dtype = "float32" )
195+
111196 def _fit (self , index : int | None = None , n_jobs = None , ** kwargs ) -> int :
112197 """Fit the model."""
113198
199+ n_jobs = n_jobs or min (cpu_count () or 1 , 8 )
200+
114201 if self ._locked_fit is not None :
115202 return n_jobs
116203
117204 if index is not None :
118205 raise NotImplementedError ("Fitting with held-out data is not supported" )
119- timepoints = kwargs .get ("timepoints" , None ) or self ._x
120- x = (np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl
121206
122- data = self ._dataset .dataobj
123- brainmask = self ._dataset .brainmask
207+ data = self ._preproces_data ()
124208
125- if self ._smooth_fwhm > 0 :
126- smoothed_img = smooth_image (
127- nb .Nifti1Image (data , self ._dataset .affine ), self ._smooth_fwhm
128- )
129- data = smoothed_img .get_fdata ()
130-
131- if self ._thresh_pct > 0 :
132- thresh_val = np .percentile (data , self ._thresh_pct )
133- data [data < thresh_val ] = 0
134-
135- # Convert data into V (voxels) x T (timepoints)
136- data = data .reshape ((- 1 , data .shape [- 1 ])) if brainmask is None else data [brainmask ]
209+ # ToDo
210+ # Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
211+ timepoints = kwargs .get ("timepoints" , None ) or self ._x
212+ x = np .asarray (timepoints , dtype = "float32" ) / self ._xlim * self ._n_ctrl
137213
138214 # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
139215 A = BSpline .design_matrix (x , self ._t , k = self ._order )
140216 AT = A .T
141217 ATdotA = AT @ A
142218
143219 # Parallelize process with joblib
144- with Parallel (n_jobs = n_jobs or min ( cpu_count () or 1 , 8 ) ) as executor :
220+ with Parallel (n_jobs = n_jobs ) as executor :
145221 results = executor (delayed (cg )(ATdotA , AT @ v ) for v in data )
146222
147- self ._locked_fit = np .array ([r [0 ] for r in results ])
223+ self ._locked_fit = np .asarray ([r [0 ] for r in results ])
148224
149225 return n_jobs
150226
151- def fit_predict (self , index : int | None = None , ** kwargs ):
227+ def fit_predict (self , index : int | None = None , ** kwargs ) -> Union [ np . ndarray , None ] :
152228 """Return the corrected volume using B-spline interpolation."""
153229
230+ # ToDo
231+ # Does the below apply to PET ? Martin has the return None statement
232+ # if index is None:
233+ # raise RuntimeError(
234+ # f"Model {self.__class__.__name__} does not allow locking.")
235+
154236 # Fit the BSpline basis on all data
155237 if self ._locked_fit is None :
156238 self ._fit (index , n_jobs = kwargs .pop ("n_jobs" , None ), ** kwargs )
0 commit comments