Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 55 additions & 3 deletions src/nifreeze/model/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
"""Time frame tolerance in seconds."""

START_INDEX_RANGE_ERROR_MSG = "start_index must be within the range of provided timepoints."
"""PET model fitting start index allowed values error."""

FIT_INDEX_OUT_OF_RANGE_ERROR_MSG = "Index out of range for available timepoints."
"""PET model fitting index out-of-range error"""


class PETModel(BaseModel):
"""A PET imaging realignment model based on B-Spline approximation."""
Expand All @@ -52,6 +58,8 @@ class PETModel(BaseModel):
"_mask",
"_smooth_fwhm",
"_thresh_pct",
"_start_index",
"_start_time",
)

def __init__(
Expand All @@ -63,6 +71,7 @@ def __init__(
order: int = 3,
smooth_fwhm: float = 10.0,
thresh_pct: float = 20.0,
start_index: int | None = None,
**kwargs,
):
"""
Expand All @@ -80,6 +89,14 @@ def __init__(
six timepoints will be used. The less control points, the smoother is the
model.

start_index : :obj:`int` or None
If provided, the model will be fitted using only timepoints starting from
this index (inclusive). Predictions for timepoints earlier than the
specified start will reuse the predicted volume for the start timepoint.
This is useful, for example, to discard a number of frames at the
beginning of the sequence, which due to their little SNR may impact
registration negatively.

"""
super().__init__(dataset, **kwargs)

Expand All @@ -97,6 +114,15 @@ def __init__(
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
raise ValueError("Last frame midpoint should not be equal or greater than duration")

# Validate and store start index / time
if start_index is None:
self._start_index = 0
else:
if start_index < 0 or start_index >= len(self._x):
raise ValueError(START_INDEX_RANGE_ERROR_MSG)
self._start_index = start_index
self._start_time = float(self._x[self._start_index])

# Calculate index coordinates in the B-Spline grid
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1

Expand All @@ -119,7 +145,9 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
if index is not None:
raise NotImplementedError("Fitting with held-out data is not supported")
timepoints = kwargs.get("timepoints", None) or self._x
x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)
timepoints_to_fit = np.asarray(timepoints, dtype="float32")[self._start_index :]

x = np.asarray((np.array(timepoints_to_fit) / self._xlim) * self._n_ctrl)

data = self._dataset.dataobj
brainmask = self._dataset.brainmask
Expand All @@ -137,6 +165,11 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
# Convert data into V (voxels) x T (timepoints)
data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]

# If fitting started later than the first frame, drop early columns so the
# temporal length matches timepoints_to_fit
if self._start_index > 0:
data = data[:, self._start_index :]

# A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
A = BSpline.design_matrix(x, self._t, k=self._order)
AT = A.T
Expand All @@ -151,7 +184,12 @@ def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
return n_jobs

def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
"""Return the corrected volume using B-spline interpolation."""
"""Return the corrected volume using B-spline interpolation.

Predictions for times earlier than the configured start_time will return
the prediction for the start_time (i.e., transforms estimated for the
start are reused for earlier low-SNR frames).
"""

# Fit the BSpline basis on all data
if self._locked_fit is None:
Expand All @@ -164,8 +202,22 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N
if index is None: # If no index, just fit the data.
return None

# Map integer indices to actual timepoints if needed
if isinstance(index, (int, np.integer)):
idx_int = int(index)
if idx_int < 0 or idx_int >= len(self._x):
raise IndexError(FIT_INDEX_OUT_OF_RANGE_ERROR_MSG)
index_time = float(self._x[idx_int])
else:
index_time = float(index)

# If the requested time is earlier than the configured start time, use the
# start time's prediction (reuse the transforms estimated for start)
if index_time < self._start_time:
index_time = self._start_time

# Project sample timing into B-Spline coordinates
x = np.asarray((index / self._xlim) * self._n_ctrl)
x = np.asarray((index_time / self._xlim) * self._n_ctrl)
A = BSpline.design_matrix(x, self._t, k=self._order)

# A is 1 (num. timepoints) x C (num. coeff)
Expand Down
90 changes: 89 additions & 1 deletion test/test_model_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import pytest

from nifreeze.data.pet import PET
from nifreeze.model.pet import PETModel
from nifreeze.model.pet import (
FIT_INDEX_OUT_OF_RANGE_ERROR_MSG,
START_INDEX_RANGE_ERROR_MSG,
PETModel,
)


@pytest.fixture
Expand Down Expand Up @@ -81,3 +85,87 @@ def test_petmodel_time_check(random_dataset):
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
with pytest.raises(ValueError):
PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0)


def test_init_start_index_error():
data = np.ones((1, 1, 1, 3), dtype=float)
dataset = PET(data)
timepoints = np.array([15.0, 45.0, 75.0], dtype=float)
xlim = 100.0

# Negative start_index raises ValueError
with pytest.raises(ValueError, match=START_INDEX_RANGE_ERROR_MSG):
PETModel(dataset, timepoints=timepoints, xlim=xlim, start_index=-1)

# start_index equal to len(timepoints) is out of range
with pytest.raises(ValueError, match=START_INDEX_RANGE_ERROR_MSG):
PETModel(dataset, timepoints=timepoints, xlim=xlim, start_index=len(timepoints))


def test_fit_predict_index_error():
data = np.ones((1, 1, 1, 3), dtype=float)
dataset = PET(data)
timepoints = np.array([15.0, 45.0, 75.0], dtype=float)
xlim = 100.0

model = PETModel(
dataset,
timepoints=timepoints,
xlim=xlim,
smooth_fwhm=0.0,
thresh_pct=0.0,
)

model.fit_predict(None)

# Requesting an negative index should raise IndexError
with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG):
model.fit_predict(index=-1)

# Index equal to len(self._x) should also raise
with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG):
model.fit_predict(index=len(timepoints))

# Index greater than to len(self._x) should also raise
with pytest.raises(IndexError, match=FIT_INDEX_OUT_OF_RANGE_ERROR_MSG):
model.fit_predict(index=len(timepoints) + 1)


def test_petmodel_start_index_reuses_start_prediction():
# Create a tiny 1-voxel 5-frame sequence with increasing signal
data = np.arange(1.0, 6.0, dtype=float).reshape((1, 1, 1, 5))
dataset = PET(data)

# Timepoints in seconds (monotonic)
timepoints = np.array([15.0, 45.0, 75.0, 105.0, 135.0], dtype=float)
xlim = 150.0

# Configure the model to start fitting at index=2 (timepoint 75s)
model = PETModel(
dataset,
timepoints=timepoints,
xlim=xlim,
smooth_fwhm=0.0, # disable smoothing for deterministic behaviour
thresh_pct=0.0, # disable thresholding
start_index=2,
)

model.fit_predict(None)

# Prediction for the configured start timepoint
pred_start = model.fit_predict(index=timepoints[2])

# Prediction for an earlier timepoint (should reuse start prediction)
pred_early = model.fit_predict(index=timepoints[1])

assert np.allclose(pred_start, pred_early), (
"Earlier frames should reuse start-frame prediction"
)

# Prediction for a later timepoint should be allowed and may differ
pred_late = model.fit_predict(index=timepoints[3])
assert pred_late is not None

assert pred_start.shape == data.shape[:3]
assert pred_early.shape == data.shape[:3]
assert pred_late.shape == data.shape[:3]
Loading