diff --git a/src/nifreeze/model/pet.py b/src/nifreeze/model/pet.py index 9428bdb91..38d50262e 100644 --- a/src/nifreeze/model/pet.py +++ b/src/nifreeze/model/pet.py @@ -38,6 +38,12 @@ DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 """Time frame tolerance in seconds.""" +START_INDEX_RANGE_ERROR_MSG = "start_index must be within the range of provided timepoints." +"""PET model fitting start index allowed values error.""" + +FIT_INDEX_OUT_OF_RANGE_ERROR_MSG = "Index out of range for available timepoints." +"""PET model fitting index out-of-range error""" + class PETModel(BaseModel): """A PET imaging realignment model based on B-Spline approximation.""" @@ -52,6 +58,8 @@ class PETModel(BaseModel): "_mask", "_smooth_fwhm", "_thresh_pct", + "_start_index", + "_start_time", ) def __init__( @@ -63,6 +71,7 @@ def __init__( order: int = 3, smooth_fwhm: float = 10.0, thresh_pct: float = 20.0, + start_index: int | None = None, **kwargs, ): """ @@ -80,6 +89,14 @@ def __init__( six timepoints will be used. The less control points, the smoother is the model. + start_index : :obj:`int` or None + If provided, the model will be fitted using only timepoints starting from + this index (inclusive). Predictions for timepoints earlier than the + specified start will reuse the predicted volume for the start timepoint. + This is useful, for example, to discard a number of frames at the + beginning of the sequence, which due to their little SNR may impact + registration negatively. + """ super().__init__(dataset, **kwargs) @@ -97,6 +114,15 @@ def __init__( if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): raise ValueError("Last frame midpoint should not be equal or greater than duration") + # Validate and store start index / time + if start_index is None: + self._start_index = 0 + else: + if start_index < 0 or start_index >= len(self._x): + raise ValueError(START_INDEX_RANGE_ERROR_MSG) + self._start_index = start_index + self._start_time = float(self._x[self._start_index]) + # Calculate index coordinates in the B-Spline grid self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 @@ -119,7 +145,9 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int: if index is not None: raise NotImplementedError("Fitting with held-out data is not supported") timepoints = kwargs.get("timepoints", None) or self._x - x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl) + timepoints_to_fit = np.asarray(timepoints, dtype="float32")[self._start_index :] + + x = np.asarray((np.array(timepoints_to_fit) / self._xlim) * self._n_ctrl) data = self._dataset.dataobj brainmask = self._dataset.brainmask @@ -137,6 +165,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int: # Convert data into V (voxels) x T (timepoints) data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask] + # If fitting started later than the first frame, drop early columns so the + # temporal length matches timepoints_to_fit + if self._start_index > 0: + data = data[:, self._start_index :] + # A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding) A = BSpline.design_matrix(x, self._t, k=self._order) AT = A.T @@ -151,7 +184,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int: return n_jobs def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]: - """Return the corrected volume using B-spline interpolation.""" + """Return the corrected volume using B-spline interpolation. + + Predictions for times earlier than the configured start_time will return + the prediction for the start_time (i.e., transforms estimated for the + start are reused for earlier low-SNR frames). + """ # Fit the BSpline basis on all data if self._locked_fit is None: @@ -164,8 +202,22 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N if index is None: # If no index, just fit the data. return None + # Map integer indices to actual timepoints if needed + if isinstance(index, (int, np.integer)): + idx_int = int(index) + if idx_int < 0 or idx_int >= len(self._x): + raise IndexError(FIT_INDEX_OUT_OF_RANGE_ERROR_MSG) + index_time = float(self._x[idx_int]) + else: + index_time = float(index) + + # If the requested time is earlier than the configured start time, use the + # start time's prediction (reuse the transforms estimated for start) + if index_time < self._start_time: + index_time = self._start_time + # Project sample timing into B-Spline coordinates - x = np.asarray((index / self._xlim) * self._n_ctrl) + x = np.asarray((index_time / self._xlim) * self._n_ctrl) A = BSpline.design_matrix(x, self._t, k=self._order) # A is 1 (num. timepoints) x C (num. coeff) diff --git a/test/test_model_pet.py b/test/test_model_pet.py index 6081a07d5..87c247775 100644 --- a/test/test_model_pet.py +++ b/test/test_model_pet.py @@ -25,7 +25,11 @@ import pytest from nifreeze.data.pet import PET -from nifreeze.model.pet import PETModel +from nifreeze.model.pet import ( + FIT_INDEX_OUT_OF_RANGE_ERROR_MSG, + START_INDEX_RANGE_ERROR_MSG, + PETModel, +) @pytest.fixture @@ -81,3 +85,87 @@ def test_petmodel_time_check(random_dataset): bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32) with pytest.raises(ValueError): PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0) + + +def test_init_start_index_error(): + data = np.ones((1, 1, 1, 3), dtype=float) + dataset = PET(data) + timepoints = np.array([15.0, 45.0, 75.0], dtype=float) + xlim = 100.0 + + # Negative start_index raises ValueError + with pytest.raises(ValueError, match=START_INDEX_RANGE_ERROR_MSG): + PETModel(dataset, timepoints=timepoints, xlim=xlim, start_index=-1) + + # start_index equal to len(timepoints) is out of range + with pytest.raises(ValueError, match=START_INDEX_RANGE_ERROR_MSG): + PETModel(dataset, timepoints=timepoints, xlim=xlim, start_index=len(timepoints)) + + +def test_fit_predict_index_error(): + data = np.ones((1, 1, 1, 3), dtype=float) + dataset = PET(data) + timepoints = np.array([15.0, 45.0, 75.0], dtype=float) + xlim = 100.0 + + model = PETModel( + dataset, + timepoints=timepoints, + xlim=xlim, + smooth_fwhm=0.0, + thresh_pct=0.0, + ) + + model.fit_predict(None) + + # Requesting an negative index should raise IndexError + with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG): + model.fit_predict(index=-1) + + # Index equal to len(self._x) should also raise + with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG): + model.fit_predict(index=len(timepoints)) + + # Index greater than to len(self._x) should also raise + with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG): + model.fit_predict(index=len(timepoints) + 1) + + +def test_petmodel_start_index_reuses_start_prediction(): + # Create a tiny 1-voxel 5-frame sequence with increasing signal + data = np.arange(1.0, 6.0, dtype=float).reshape((1, 1, 1, 5)) + dataset = PET(data) + + # Timepoints in seconds (monotonic) + timepoints = np.array([15.0, 45.0, 75.0, 105.0, 135.0], dtype=float) + xlim = 150.0 + + # Configure the model to start fitting at index=2 (timepoint 75s) + model = PETModel( + dataset, + timepoints=timepoints, + xlim=xlim, + smooth_fwhm=0.0, # disable smoothing for deterministic behaviour + thresh_pct=0.0, # disable thresholding + start_index=2, + ) + + model.fit_predict(None) + + # Prediction for the configured start timepoint + pred_start = model.fit_predict(index=timepoints[2]) + + # Prediction for an earlier timepoint (should reuse start prediction) + pred_early = model.fit_predict(index=timepoints[1]) + + assert np.allclose(pred_start, pred_early), ( + "Earlier frames should reuse start-frame prediction" + ) + + # Prediction for a later timepoint should be allowed and may differ + pred_late = model.fit_predict(index=timepoints[3]) + assert pred_late is not None + + assert pred_start.shape == data.shape[:3] + assert pred_early.shape == data.shape[:3] + assert pred_late.shape == data.shape[:3]