3838DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3939"""Time frame tolerance in seconds."""
4040
41+ START_INDEX_RANGE_ERROR_MSG = "start_index must be within the range of provided timepoints."
42+ """PET model fitting start index allowed values error."""
43+
44+ FIT_INDEX_OUT_OF_RANGE_ERROR_MSG = "Index out of range for available timepoints."
45+ """PET modl fitting index out-of-range error"""
46+
4147
4248class PETModel (BaseModel ):
4349 """A PET imaging realignment model based on B-Spline approximation."""
@@ -52,6 +58,8 @@ class PETModel(BaseModel):
5258 "_mask" ,
5359 "_smooth_fwhm" ,
5460 "_thresh_pct" ,
61+ "_start_index" ,
62+ "_start_time" ,
5563 )
5664
5765 def __init__ (
@@ -63,6 +71,7 @@ def __init__(
6371 order : int = 3 ,
6472 smooth_fwhm : float = 10.0 ,
6573 thresh_pct : float = 20.0 ,
74+ start_index : int | None = None ,
6675 ** kwargs ,
6776 ):
6877 """
@@ -80,6 +89,11 @@ def __init__(
8089 six timepoints will be used. The less control points, the smoother is the
8190 model.
8291
92+ start_index : :obj:`int` or None
93+ If provided, the model will be fitted using only timepoints starting from
94+ this index (inclusive). Predictions for timepoints earlier than the
95+ specified start will reuse the predicted volume for the start timepoint.
96+
8397 """
8498 super ().__init__ (dataset , ** kwargs )
8599
@@ -97,6 +111,15 @@ def __init__(
97111 if self ._x [- 1 ] > (self ._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL ):
98112 raise ValueError ("Last frame midpoint should not be equal or greater than duration" )
99113
114+ # Validate and store start index / time
115+ if start_index is None :
116+ self ._start_index = 0
117+ else :
118+ if start_index < 0 or start_index >= len (self ._x ):
119+ raise ValueError (START_INDEX_RANGE_ERROR_MSG )
120+ self ._start_index = start_index
121+ self ._start_time = float (self ._x [self ._start_index ])
122+
100123 # Calculate index coordinates in the B-Spline grid
101124 self ._n_ctrl = n_ctrl or (len (timepoints ) // 4 ) + 1
102125
@@ -119,7 +142,9 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
119142 if index is not None :
120143 raise NotImplementedError ("Fitting with held-out data is not supported" )
121144 timepoints = kwargs .get ("timepoints" , None ) or self ._x
122- x = np .asarray ((np .array (timepoints , dtype = "float32" ) / self ._xlim ) * self ._n_ctrl )
145+ timepoints_to_fit = np .asarray (timepoints , dtype = "float32" )[self ._start_index :]
146+
147+ x = np .asarray ((np .array (timepoints_to_fit ) / self ._xlim ) * self ._n_ctrl )
123148
124149 data = self ._dataset .dataobj
125150 brainmask = self ._dataset .brainmask
@@ -137,6 +162,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
137162 # Convert data into V (voxels) x T (timepoints)
138163 data = data .reshape ((- 1 , data .shape [- 1 ])) if brainmask is None else data [brainmask ]
139164
165+ # If fitting started later than the first frame, drop early columns so the
166+ # temporal length matches timepoints_to_fit
167+ if self ._start_index > 0 :
168+ data = data [:, self ._start_index :]
169+
140170 # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
141171 A = BSpline .design_matrix (x , self ._t , k = self ._order )
142172 AT = A .T
@@ -151,7 +181,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
151181 return n_jobs
152182
153183 def fit_predict (self , index : int | None = None , ** kwargs ) -> Union [np .ndarray , None ]:
154- """Return the corrected volume using B-spline interpolation."""
184+ """Return the corrected volume using B-spline interpolation.
185+
186+ Predictions for times earlier than the configured start_time will return
187+ the prediction for the start_time (i.e., transforms estimated for the
188+ start are reused for earlier low-SNR frames).
189+ """
155190
156191 # Fit the BSpline basis on all data
157192 if self ._locked_fit is None :
@@ -164,8 +199,22 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N
164199 if index is None : # If no index, just fit the data.
165200 return None
166201
202+ # Map integer indices to actual timepoints if needed
203+ if isinstance (index , (int , np .integer )):
204+ idx_int = int (index )
205+ if idx_int < 0 or idx_int >= len (self ._x ):
206+ raise IndexError (FIT_INDEX_OUT_OF_RANGE_ERROR_MSG )
207+ index_time = float (self ._x [idx_int ])
208+ else :
209+ index_time = float (index )
210+
211+ # If the requested time is earlier than the configured start time, use the
212+ # start time's prediction (reuse the transforms estimated for start)
213+ if index_time < self ._start_time :
214+ index_time = self ._start_time
215+
167216 # Project sample timing into B-Spline coordinates
168- x = np .asarray ((index / self ._xlim ) * self ._n_ctrl )
217+ x = np .asarray ((index_time / self ._xlim ) * self ._n_ctrl )
169218 A = BSpline .design_matrix (x , self ._t , k = self ._order )
170219
171220 # A is 1 (num. timepoints) x C (num. coeff)
0 commit comments