3838DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3939"""Time frame tolerance in seconds."""
4040
41+ START_INDEX_RANGE_ERROR_MSG = "start_index must be within the range of provided timepoints."
42+ """PET model fitting start index allowed values error."""
43+
44+ FIT_INDEX_OUT_OF_RANGE_ERROR_MSG = "Index out of range for available timepoints."
45+ """PET model fitting index out-of-range error"""
46+
4147
4248class PETModel (BaseModel ):
4349 """A PET imaging realignment model based on B-Spline approximation."""
@@ -52,6 +58,8 @@ class PETModel(BaseModel):
5258 "_mask" ,
5359 "_smooth_fwhm" ,
5460 "_thresh_pct" ,
61+ "_start_index" ,
62+ "_start_time" ,
5563 )
5664
5765 def __init__ (
@@ -63,6 +71,7 @@ def __init__(
6371 order : int = 3 ,
6472 smooth_fwhm : float = 10.0 ,
6573 thresh_pct : float = 20.0 ,
74+ start_index : int | None = None ,
6675 ** kwargs ,
6776 ):
6877 """
@@ -80,6 +89,14 @@ def __init__(
8089 six timepoints will be used. The less control points, the smoother is the
8190 model.
8291
92+ start_index : :obj:`int` or None
93+ If provided, the model will be fitted using only timepoints starting from
94+ this index (inclusive). Predictions for timepoints earlier than the
95+ specified start will reuse the predicted volume for the start timepoint.
96+ This is useful, for example, to discard a number of frames at the
97+ beginning of the sequence, which due to their little SNR may impact
98+ registration negatively.
99+
83100 """
84101 super ().__init__ (dataset , ** kwargs )
85102
@@ -97,6 +114,15 @@ def __init__(
97114 if self ._x [- 1 ] > (self ._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL ):
98115 raise ValueError ("Last frame midpoint should not be equal or greater than duration" )
99116
117+ # Validate and store start index / time
118+ if start_index is None :
119+ self ._start_index = 0
120+ else :
121+ if start_index < 0 or start_index >= len (self ._x ):
122+ raise ValueError (START_INDEX_RANGE_ERROR_MSG )
123+ self ._start_index = start_index
124+ self ._start_time = float (self ._x [self ._start_index ])
125+
100126 # Calculate index coordinates in the B-Spline grid
101127 self ._n_ctrl = n_ctrl or (len (timepoints ) // 4 ) + 1
102128
@@ -119,7 +145,9 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
119145 if index is not None :
120146 raise NotImplementedError ("Fitting with held-out data is not supported" )
121147 timepoints = kwargs .get ("timepoints" , None ) or self ._x
122- x = np .asarray ((np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl )
148+ timepoints_to_fit = np .asarray (timepoints , dtype = "float32" )[self ._start_index :]
149+
150+ x = np .asarray ((np .array (timepoints_to_fit ) / self ._xlim ) * self ._n_ctrl )
123151
124152 data = self ._dataset .dataobj
125153 brainmask = self ._dataset .brainmask
@@ -137,6 +165,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
137165 # Convert data into V (voxels) x T (timepoints)
138166 data = data .reshape ((- 1 , data .shape [- 1 ])) if brainmask is None else data [brainmask ]
139167
168+ # If fitting started later than the first frame, drop early columns so the
169+ # temporal length matches timepoints_to_fit
170+ if self ._start_index > 0 :
171+ data = data [:, self ._start_index :]
172+
140173 # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
141174 A = BSpline .design_matrix (x , self ._t , k = self ._order )
142175 AT = A .T
@@ -151,7 +184,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
151184 return n_jobs
152185
153186 def fit_predict (self , index : int | None = None , ** kwargs ) -> Union [np .ndarray , None ]:
154- """Return the corrected volume using B-spline interpolation."""
187+ """Return the corrected volume using B-spline interpolation.
188+
189+ Predictions for times earlier than the configured start_time will return
190+ the prediction for the start_time (i.e., transforms estimated for the
191+ start are reused for earlier low-SNR frames).
192+ """
155193
156194 # Fit the BSpline basis on all data
157195 if self ._locked_fit is None :
@@ -164,8 +202,22 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N
164202 if index is None : # If no index, just fit the data.
165203 return None
166204
205+ # Map integer indices to actual timepoints if needed
206+ if isinstance (index , (int , np .integer )):
207+ idx_int = int (index )
208+ if idx_int < 0 or idx_int >= len (self ._x ):
209+ raise IndexError (FIT_INDEX_OUT_OF_RANGE_ERROR_MSG )
210+ index_time = float (self ._x [idx_int ])
211+ else :
212+ index_time = float (index )
213+
214+ # If the requested time is earlier than the configured start time, use the
215+ # start time's prediction (reuse the transforms estimated for start)
216+ if index_time < self ._start_time :
217+ index_time = self ._start_time
218+
167219 # Project sample timing into B-Spline coordinates
168- x = np .asarray ((index / self ._xlim ) * self ._n_ctrl )
220+ x = np .asarray ((index_time / self ._xlim ) * self ._n_ctrl )
169221 A = BSpline .design_matrix (x , self ._t , k = self ._order )
170222
171223 # A is 1 (num. timepoints) x C (num. coeff)
0 commit comments