Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/notebooks/pet_motion_estimation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -406,9 +406,9 @@
}
],
"source": [
"from nifreeze.model import PETModel\n",
"from nifreeze.model import BSplinePETModel\n",
"\n",
"model = PETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
"model = BSplinePETModel(dataset=pet_dataset, timepoints=pet_dataset.midframe, xlim=7000)"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions src/nifreeze/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from nifreeze.data.base import BaseDataset
from nifreeze.data.pet import PET
from nifreeze.model.base import BaseModel, ModelFactory
from nifreeze.model.pet import PETModel
from nifreeze.model.pet import BSplinePETModel
from nifreeze.registration.ants import (
Registration,
_prepare_registration_data,
Expand Down Expand Up @@ -261,8 +261,8 @@ def run(self, pet_dataset: PET, omp_nthreads: int | None = None) -> list:
total_duration=pet_dataset.total_duration,
)

# Instantiate PETModel explicitly
model = PETModel(
# Instantiate the PET model explicitly
model = BSplinePETModel(
dataset=train_dataset,
timepoints=train_times,
xlim=pet_dataset.total_duration,
Expand Down
4 changes: 2 additions & 2 deletions src/nifreeze/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
DTIModel,
GPModel,
)
from nifreeze.model.pet import PETModel
from nifreeze.model.pet import BSplinePETModel

__all__ = (
"ModelFactory",
Expand All @@ -43,5 +43,5 @@
"DTIModel",
"GPModel",
"TrivialModel",
"PETModel",
"BSplinePETModel",
)
185 changes: 130 additions & 55 deletions src/nifreeze/model/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#
"""Models for nuclear imaging."""

from abc import ABC, ABCMeta, abstractmethod
from os import cpu_count
from typing import Union

Expand All @@ -39,56 +40,78 @@
"""Time frame tolerance in seconds."""


class PETModel(BaseModel):
"""A PET imaging realignment model based on B-Spline approximation."""
def _exec_fit(model, data, chunk=None, **kwargs):
return model.fit(data, **kwargs), chunk

__slots__ = (
"_t",
"_x",
"_xlim",
"_order",
"_n_ctrl",
"_datashape",
"_mask",
"_smooth_fwhm",
"_thresh_pct",
)

def _exec_predict(model, chunk=None, **kwargs):
"""Propagate model parameters and call predict."""
return np.squeeze(model.predict(**kwargs)), chunk


class BasePETModel(BaseModel, ABC):
"""Interface and default methods for PET models."""

__metaclass__ = ABCMeta

__slots__ = {
"_data_mask": "A mask for the voxels that will be fitted and predicted",
"_x": "",
"_xlim": "",
"_smooth_fwhm": "FWHM in mm over which to smooth",
"_thresh_pct": "Thresholding percentile for the signal",
"_model_class": "Defining a model class",
"_modelargs": "Arguments acceptable by the underlying model",
"_models": "List with one or more (if parallel execution) model instances",
}

def __init__(
self,
dataset: PET,
timepoints: list | np.ndarray | None = None,
xlim: float | None = None,
n_ctrl: int | None = None,
order: int = 3,
xlim: list | np.ndarray | None = None,
smooth_fwhm: float = 10.0,
thresh_pct: float = 20.0,
**kwargs,
):
"""
Create the B-Spline interpolating matrix.
"""Initialization.

Parameters:
-----------
timepoints : :obj:`list`
Parameters
----------
timepoints : :obj:`list` or :obj:`~np.ndarray`
The timing (in sec) of each PET volume.
E.g., ``[15., 45., 75., 105., 135., 165., 210., 270., 330.,
420., 540., 750., 1050., 1350., 1650., 1950., 2250., 2550.]``

n_ctrl : :obj:`int`
Number of B-Spline control points. If `None`, then one control point every
six timepoints will be used. The less control points, the smoother is the
model.

xlim : .
.
smooth_fwhm : obj:`float`
FWHM in mm over which to smooth the signal.
thresh_pct : obj:`float`
Thresholding percentile for the signal.
"""

super().__init__(dataset, **kwargs)

# Duck typing, instead of explicitly testing for PET type
if not hasattr(dataset, "total_duration"):
raise TypeError("Dataset MUST be a PET object.")

if not hasattr(dataset, "midframe"):
raise ValueError("Dataset MUST have a midframe.")

self._data_mask = (
dataset.brainmask
if dataset.brainmask is not None
else np.ones(dataset.dataobj.shape[:3], dtype=bool)
)

# ToDo
# Are timepoints and xlim features that all PET models require ??
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@mnoergaard question for you.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say so, since we are dealing with motion correction, and hence would need temporal information. Currently, the PET model needs both the sampling midpoints and the scan’s total duration to work, so enforcing their presence in the base class keeps the API consistent. When new PET models that do not need temporal information are introduced, we can revisit this requirement; until then, providing dataset.midframe and dataset.total_duration (as done in the unit tests) is the intended usage.

Copy link
Contributor Author

@jhlegarreta jhlegarreta Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I see that this is related to #204 (review). I think the naming should then be made consistent across the PET data class and the model class. Also, I do not see why the model is given the timepoints, if these are hold by the PET data class, i.e. from #204 (review)

(...) midframe is where the dataset stores real-world frame timing; timepoints/_x is the copy of those timings handed to the model

then we should not be providing the model with a copy of them.

If we make the parallel with the DWI class, the gradients are obtained from the dataset, e.g.:

gradient = self._dataset.gradients[:, index]

Also, can the xlim name be made somehow more descriptive or is it a name that is commonly used within the PET domain?

if timepoints is None or xlim is None:
raise TypeError("timepoints must be provided in initialization")
raise ValueError("`timepoints` and `xlim` must be specified and have a nonzero value.")

self._order = order
self._x = np.array(timepoints, dtype="float32")
self._xlim = xlim
self._x = np.asarray(timepoints, dtype="float32")
self._xlim = np.asarray(xlim)
self._smooth_fwhm = smooth_fwhm
self._thresh_pct = thresh_pct

Expand All @@ -97,62 +120,114 @@ def __init__(
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
raise ValueError("Last frame midpoint should not be equal or greater than duration")

# Calculate index coordinates in the B-Spline grid
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
def _preprocess_data(self) -> np.ndarray:
# ToDo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the todo here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I recall correctly, the idea there is to preprocess only the data that is required; e.g. if we are using only a subset of the volumes, then we shouldn't preprocess the entire data, e.g.

data, _, gtab = self._dataset[idxmask]
# Select voxels within mask or just unravel 3D if no mask
data = data[brainmask, ...] if brainmask is not None else data.reshape(-1, data.shape[-1])

The data, _, gtab = self._dataset[idxmask] is commented out in here immediately after the ToDo to as a pointer to that idea.

# data, _, gtab = self._dataset[idxmask] ### This needs the PET data model to be changed
data = self._dataset.dataobj
brainmask = self._dataset.brainmask

# B-Spline knots
self._t = np.arange(-3, float(self._n_ctrl) + 4, dtype="float32")
# Preprocess the data
if self._smooth_fwhm > 0:
smoothed_img = smooth_image(
nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
)
data = smoothed_img.get_fdata()

self._datashape = None
self._mask = None
if self._thresh_pct > 0:
thresh_val = np.percentile(data, self._thresh_pct)
data[data < thresh_val] = 0

# Convert data into V (voxels) x T (timepoints)
return data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]

@property
def is_fitted(self) -> bool:
return self._locked_fit is not None

@abstractmethod
def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
"""Predict the corrected volume."""
return None


class BSplinePETModel(BasePETModel):
"""A PET imaging realignment model based on B-Spline approximation."""

__slots__ = (
"_t",
"_order",
"_n_ctrl",
)

def __init__(
self,
dataset: PET,
n_ctrl: int | None = None,
order: int = 3,
**kwargs,
):
"""Create the B-Spline interpolating matrix.

Parameters
----------
n_ctrl : :obj:`int`
Number of B-Spline control points. If `None`, then one control point every
six timepoints will be used. The less control points, the smoother is the
model.
order : :obj:`int`
Order of the B-Spline approximation.
"""

super().__init__(dataset, **kwargs)

self._order = order

# Calculate index coordinates in the B-Spline grid
self._n_ctrl = n_ctrl or (len(self._x) // 4) + 1

# B-Spline knots
self._t = np.arange(-3, self._n_ctrl + 4, dtype="float32")

def _fit(self, index: int | None = None, n_jobs=None, **kwargs) -> int:
"""Fit the model."""

n_jobs = n_jobs or min(cpu_count() or 1, 8)

if self._locked_fit is not None:
return n_jobs

if index is not None:
raise NotImplementedError("Fitting with held-out data is not supported")
timepoints = kwargs.get("timepoints", None) or self._x
x = np.asarray((np.array(timepoints, dtype="float32") / self._xlim) * self._n_ctrl)

data = self._dataset.dataobj
brainmask = self._dataset.brainmask

if self._smooth_fwhm > 0:
smoothed_img = smooth_image(
nb.Nifti1Image(data, self._dataset.affine), self._smooth_fwhm
)
data = smoothed_img.get_fdata()

if self._thresh_pct > 0:
thresh_val = np.percentile(data, self._thresh_pct)
data[data < thresh_val] = 0
data = self._preprocess_data()

# Convert data into V (voxels) x T (timepoints)
data = data.reshape((-1, data.shape[-1])) if brainmask is None else data[brainmask]
# ToDo
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Allowing _fit to override timepoints through kwargs duplicates information that was already validated and stored on self._x during initialization. Dropping that extra kwarg simplifies the API and avoids inconsistent inputs

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure I understand: maybe the reply is answering a question elsewhere?

# Does not make sense to make timepoints be a kwarg if it is provided as a named parameter to __init__
timepoints = kwargs.get("timepoints", None) or self._x
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be removed

x = np.asarray(timepoints, dtype="float32") / self._xlim * self._n_ctrl
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can be replaced with

x = self._x / self._xlim * self._n_ctrl


# A.shape = (T, K - 4); T= n. timepoints, K= n. knots (with padding)
A = BSpline.design_matrix(x, self._t, k=self._order)
AT = A.T
ATdotA = AT @ A

# Parallelize process with joblib
with Parallel(n_jobs=n_jobs or min(cpu_count() or 1, 8)) as executor:
with Parallel(n_jobs=n_jobs) as executor:
results = executor(delayed(cg)(ATdotA, AT @ v) for v in data)

self._locked_fit = np.array([r[0] for r in results])
self._locked_fit = np.asarray([r[0] for r in results])

return n_jobs

def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, None]:
"""Return the corrected volume using B-spline interpolation."""

# ToDo
# Does the below apply to PET ? Martin has the return None statement
# if index is None:
# raise RuntimeError(
# f"Model {self.__class__.__name__} does not allow locking.")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another question for @mnoergaard.


# Fit the BSpline basis on all data
if self._locked_fit is None:
self._fit(index, n_jobs=kwargs.pop("n_jobs", None), **kwargs)
Expand Down
2 changes: 1 addition & 1 deletion test/test_integration_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def fit_predict(self, index):
return None
return np.zeros(self.dataset.shape3d, dtype=np.float32)

monkeypatch.setattr("nifreeze.estimator.PETModel", DummyModel)
monkeypatch.setattr("nifreeze.estimator.BSplinePETModel", DummyModel)

class DummyRegistration:
def __init__(self, *args, **kwargs):
Expand Down
20 changes: 15 additions & 5 deletions test/test_model_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pytest

from nifreeze.data.pet import PET
from nifreeze.model.pet import PETModel
from nifreeze.model.pet import BSplinePETModel


@pytest.fixture
Expand All @@ -49,9 +49,19 @@ def random_dataset(setup_random_pet_data) -> PET:
)


def test_pet_base_model():
from nifreeze.model.pet import BasePETModel

with pytest.raises(
TypeError,
match="Can't instantiate abstract class BasePETModel without an implementation for abstract method 'fit_predict'",
):
BasePETModel(None)


@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
def test_petmodel_fit_predict(random_dataset):
model = PETModel(
model = BSplinePETModel(
dataset=random_dataset,
timepoints=random_dataset.midframe,
xlim=random_dataset.total_duration,
Expand All @@ -72,12 +82,12 @@ def test_petmodel_fit_predict(random_dataset):

@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
def test_petmodel_invalid_init(random_dataset):
with pytest.raises(TypeError):
PETModel(dataset=random_dataset)
with pytest.raises(ValueError):
BSplinePETModel(dataset=random_dataset)


@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
def test_petmodel_time_check(random_dataset):
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
with pytest.raises(ValueError):
PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0)
BSplinePETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0)
Loading