Skip to content

Commit 5f5a977

Browse files
committed
ENH: Add start time attribute to PET model
Add start time attribute to PET model: allows to fit a PET model using the frames starting at `frame_index`.
1 parent e92e725 commit 5f5a977

File tree

2 files changed

+141
-4
lines changed

2 files changed

+141
-4
lines changed

src/nifreeze/model/pet.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@
3838
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3939
"""Time frame tolerance in seconds."""
4040

41+
START_INDEX_RANGE_ERROR_MSG = "start_index must be within the range of provided timepoints."
42+
"""PET model fitting start index allowed values error."""
43+
44+
FIT_INDEX_OUT_OF_RANGE_ERROR_MSG = "Index out of range for available timepoints."
45+
"""PET modl fitting index out-of-range error"""
46+
4147

4248
class PETModel(BaseModel):
4349
"""A PET imaging realignment model based on B-Spline approximation."""
@@ -52,6 +58,8 @@ class PETModel(BaseModel):
5258
"_mask",
5359
"_smooth_fwhm",
5460
"_thresh_pct",
61+
"_start_index",
62+
"_start_time",
5563
)
5664

5765
def __init__(
@@ -63,6 +71,7 @@ def __init__(
6371
order: int = 3,
6472
smooth_fwhm: float = 10.0,
6573
thresh_pct: float = 20.0,
74+
start_index: int | None = None,
6675
**kwargs,
6776
):
6877
"""
@@ -80,6 +89,11 @@ def __init__(
8089
six timepoints will be used. The less control points, the smoother is the
8190
model.
8291
92+
start_index : :obj:`int` or None
93+
If provided, the model will be fitted using only timepoints starting from
94+
this index (inclusive). Predictions for timepoints earlier than the
95+
specified start will reuse the predicted volume for the start timepoint.
96+
8397
"""
8498
super().__init__(dataset, **kwargs)
8599

@@ -97,6 +111,15 @@ def __init__(
97111
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
98112
raise ValueError("Last frame midpoint should not be equal or greater than duration")
99113

114+
# Validate and store start index / time
115+
if start_index is None:
116+
self._start_index = 0
117+
else:
118+
if start_index < 0 or start_index >= len(self._x):
119+
raise ValueError(START_INDEX_RANGE_ERROR_MSG)
120+
self._start_index = start_index
121+
self._start_time = float(self._x[self._start_index])
122+
100123
# Calculate index coordinates in the B-Spline grid
101124
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
102125

@@ -119,7 +142,9 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
119142
if index is not None:
120143
raise NotImplementedError("Fitting with held-out data is not supported")
121144
timepoints = kwargs.get("timepoints", None) or self._x
122-
x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)
145+
timepoints_to_fit = np.asarray(timepoints, dtype="float32")[self._start_index :]
146+
147+
x = np.asarray((np.array(timepoints_to_fit) / self._xlim) * self._n_ctrl)
123148

124149
data = self._dataset.dataobj
125150
brainmask = self._dataset.brainmask
@@ -137,6 +162,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
137162
# Convert data into V (voxels) x T (timepoints)
138163
data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
139164

165+
# If fitting started later than the first frame, drop early columns so the
166+
# temporal length matches timepoints_to_fit
167+
if self._start_index > 0:
168+
data = data[:, self._start_index :]
169+
140170
# A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
141171
A = BSpline.design_matrix(x, self._t, k=self._order)
142172
AT = A.T
@@ -151,7 +181,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
151181
return n_jobs
152182

153183
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
154-
"""Return the corrected volume using B-spline interpolation."""
184+
"""Return the corrected volume using B-spline interpolation.
185+
186+
Predictions for times earlier than the configured start_time will return
187+
the prediction for the start_time (i.e., transforms estimated for the
188+
start are reused for earlier low-SNR frames).
189+
"""
155190

156191
# Fit the BSpline basis on all data
157192
if self._locked_fit is None:
@@ -164,8 +199,22 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N
164199
if index is None: # If no index, just fit the data.
165200
return None
166201

202+
# Map integer indices to actual timepoints if needed
203+
if isinstance(index, (int, np.integer)):
204+
idx_int = int(index)
205+
if idx_int < 0 or idx_int >= len(self._x):
206+
raise IndexError(FIT_INDEX_OUT_OF_RANGE_ERROR_MSG)
207+
index_time = float(self._x[idx_int])
208+
else:
209+
index_time = float(index)
210+
211+
# If the requested time is earlier than the configured start time, use the
212+
# start time's prediction (reuse the transforms estimated for start)
213+
if index_time < self._start_time:
214+
index_time = self._start_time
215+
167216
# Project sample timing into B-Spline coordinates
168-
x = np.asarray((index / self._xlim) * self._n_ctrl)
217+
x = np.asarray((index_time / self._xlim) * self._n_ctrl)
169218
A = BSpline.design_matrix(x, self._t, k=self._order)
170219

171220
# A is 1 (num. timepoints) x C (num. coeff)

test/test_model_pet.py

Lines changed: 89 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
import pytest
2626

2727
from nifreeze.data.pet import PET
28-
from nifreeze.model.pet import PETModel
28+
from nifreeze.model.pet import (
29+
FIT_INDEX_OUT_OF_RANGE_ERROR_MSG,
30+
START_INDEX_RANGE_ERROR_MSG,
31+
PETModel,
32+
)
2933

3034

3135
@pytest.fixture
@@ -81,3 +85,87 @@ def test_petmodel_time_check(random_dataset):
8185
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
8286
with pytest.raises(ValueError):
8387
PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0)
88+
89+
90+
def test_init_start_index_error():
91+
data = np.ones((1, 1, 1, 3), dtype=float)
92+
dataset = PET(data)
93+
timepoints = np.array([15.0, 45.0, 75.0], dtype=float)
94+
xlim = 100.0
95+
96+
# Negative start_index raises ValueError
97+
with pytest.raises(ValueError, match=START_INDEX_RANGE_ERROR_MSG):
98+
PETModel(dataset, timepoints=timepoints, xlim=xlim, start_index=-1)
99+
100+
# start_index equal to len(timepoints) is out of range
101+
with pytest.raises(ValueError, match=START_INDEX_RANGE_ERROR_MSG):
102+
PETModel(dataset, timepoints=timepoints, xlim=xlim, start_index=len(timepoints))
103+
104+
105+
def test_fit_predict_index_error():
106+
data = np.ones((1, 1, 1, 3), dtype=float)
107+
dataset = PET(data)
108+
timepoints = np.array([15.0, 45.0, 75.0], dtype=float)
109+
xlim = 100.0
110+
111+
model = PETModel(
112+
dataset,
113+
timepoints=timepoints,
114+
xlim=xlim,
115+
smooth_fwhm=0.0,
116+
thresh_pct=0.0,
117+
)
118+
119+
model.fit_predict(None)
120+
121+
# ToDo
122+
# Check this: the code does not check for existence
123+
# Requesting an integer index that does not exist should raise IndexError
124+
with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG):
125+
model.fit_predict(index=10)
126+
127+
# ToDo
128+
# Check this
129+
# Index equal to len(self._x) is out of range and should also raise
130+
with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG):
131+
model.fit(index=len(x))
132+
133+
134+
def test_petmodel_start_index_reuses_start_prediction():
135+
# Create a tiny 1-voxel 5-frame sequence with increasing signal
136+
data = np.arange(1.0, 6.0, dtype=float).reshape((1, 1, 1, 5))
137+
dataset = PET(data)
138+
139+
# Timepoints in seconds (monotonic)
140+
timepoints = np.array([15.0, 45.0, 75.0, 105.0, 135.0], dtype=float)
141+
xlim = 150.0
142+
143+
# Configure the model to start fitting at index=2 (timepoint 75s)
144+
model = PETModel(
145+
dataset,
146+
timepoints=timepoints,
147+
xlim=xlim,
148+
smooth_fwhm=0.0, # disable smoothing for deterministic behaviour
149+
thresh_pct=0.0, # disable thresholding
150+
start_index=2,
151+
)
152+
153+
model.fit_predict(None)
154+
155+
# Prediction for the configured start timepoint
156+
pred_start = model.fit_predict(index=timepoints[2])
157+
158+
# Prediction for an earlier timepoint (should reuse start prediction)
159+
pred_early = model.fit_predict(index=timepoints[1])
160+
161+
assert np.allclose(pred_start, pred_early), (
162+
"Earlier frames should reuse start-frame prediction"
163+
)
164+
165+
# Prediction for a later timepoint should be allowed and may differ
166+
pred_late = model.fit_predict(index=timepoints[3])
167+
assert pred_late is not None
168+
169+
assert pred_start.shape == data.shape[:3]
170+
assert pred_early.shape == data.shape[:3]
171+
assert pred_late.shape == data.shape[:3]

0 commit comments

Comments
 (0)