Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/nifreeze/model/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,13 @@
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
"""Time frame tolerance in seconds."""

TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG = "timepoints must be provided in initialization."
"""PET model timepoint data missing error message."""
FIRST_MIDPOINT_VALUE_ERROR_MSG = "First frame midpoint should not be zero or negative."
"""PET model first midpoint value error message."""
LAST_MIDPOINT_VALUE_ERROR_MSG = "Last frame midpoint should not be equal or greater than duration."
"""PET model last midpoint value error message."""


class PETModel(BaseModel):
"""A PET imaging realignment model based on B-Spline approximation."""
Expand Down Expand Up @@ -84,7 +91,7 @@ def __init__(
super().__init__(dataset, **kwargs)

if timepoints is None or xlim is None:
raise TypeError("timepoints must be provided in initialization")
raise TypeError(TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG)

self._order = order
self._x = np.array(timepoints, dtype="float32")
Expand All @@ -93,9 +100,9 @@ def __init__(
self._thresh_pct = thresh_pct

if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL:
raise ValueError("First frame midpoint should not be zero or negative")
raise ValueError(FIRST_MIDPOINT_VALUE_ERROR_MSG)
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
raise ValueError("Last frame midpoint should not be equal or greater than duration")
raise ValueError(LAST_MIDPOINT_VALUE_ERROR_MSG)

# Calculate index coordinates in the B-Spline grid
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1
Expand Down
45 changes: 38 additions & 7 deletions test/test_model_pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
import pytest

from nifreeze.data.pet import PET
from nifreeze.model.pet import PETModel
from nifreeze.model.pet import (
DEFAULT_TIMEFRAME_MIDPOINT_TOL,
FIRST_MIDPOINT_VALUE_ERROR_MSG,
LAST_MIDPOINT_VALUE_ERROR_MSG,
TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG,
PETModel,
)


@pytest.fixture
Expand Down Expand Up @@ -71,13 +77,38 @@ def test_petmodel_fit_predict(random_dataset):


@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
def test_petmodel_invalid_init(random_dataset):
with pytest.raises(TypeError):
def test_petmodel_init_mandatory_attr_errors(random_dataset):
with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG):
PETModel(dataset=random_dataset)

xlim = 55.0
with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG):
PETModel(dataset=random_dataset, xlim=xlim)

timepoints = np.array([20, 30, 40, 50, 60], dtype=np.float32)
with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG):
PETModel(dataset=random_dataset, timepoints=timepoints)


@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
def test_petmodel_time_check(random_dataset):
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
with pytest.raises(ValueError):
PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0)
def test_petmodel_first_midpoint_error(random_dataset):
timepoints = np.array([0, 10, 20, 30, 50], dtype=np.float32)
xlim = 60.0
with pytest.raises(ValueError, match=FIRST_MIDPOINT_VALUE_ERROR_MSG):
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)

timepoints[0] = DEFAULT_TIMEFRAME_MIDPOINT_TOL
with pytest.raises(ValueError, match=FIRST_MIDPOINT_VALUE_ERROR_MSG):
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)


@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 40.0)
def test_petmodel_last_midpoint_error(random_dataset):
xlim = 45.0
timepoints = np.array([5, 10, 20, 30, 50], dtype=np.float32)
with pytest.raises(ValueError, match=LAST_MIDPOINT_VALUE_ERROR_MSG):
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)

timepoints[-1] = xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL
with pytest.raises(ValueError, match=LAST_MIDPOINT_VALUE_ERROR_MSG):
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)