Skip to content

Commit 9dacdc4

Browse files
committed
REF: Fix PET model exception test functions
Fix PET model exception test functions: - Refactor functions so that the last midpoint value exception is properly captured: rename tests functions so that their purposes becomes apparent from the name and add a test to check the last midpoint value exception. - Check both `midframe` and `xlim` missing exceptions. - Define the exception messages so that tests can check exactly the captured message.
1 parent e92e725 commit 9dacdc4

File tree

2 files changed

+48
-10
lines changed

2 files changed

+48
-10
lines changed

src/nifreeze/model/pet.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,13 @@
3838
DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2
3939
"""Time frame tolerance in seconds."""
4040

41+
TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG = "timepoints must be provided in initialization."
42+
"""PET model timepoint data missing error message."""
43+
FIRST_MIDPOINT_VALUE_ERROR_MSG = "First frame midpoint should not be zero or negative."
44+
"""PET model first midpoint value error message."""
45+
LAST_MIDPOINT_VALUE_ERROR_MSG = "Last frame midpoint should not be equal or greater than duration."
46+
"""PET model last midpoint value error message."""
47+
4148

4249
class PETModel(BaseModel):
4350
"""A PET imaging realignment model based on B-Spline approximation."""
@@ -84,7 +91,7 @@ def __init__(
8491
super().__init__(dataset, **kwargs)
8592

8693
if timepoints is None or xlim is None:
87-
raise TypeError("timepoints must be provided in initialization")
94+
raise TypeError(TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG)
8895

8996
self._order = order
9097
self._x = np.array(timepoints, dtype="float32")
@@ -93,9 +100,9 @@ def __init__(
93100
self._thresh_pct = thresh_pct
94101

95102
if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL:
96-
raise ValueError("First frame midpoint should not be zero or negative")
103+
raise ValueError(FIRST_MIDPOINT_VALUE_ERROR_MSG)
97104
if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL):
98-
raise ValueError("Last frame midpoint should not be equal or greater than duration")
105+
raise ValueError(LAST_MIDPOINT_VALUE_ERROR_MSG)
99106

100107
# Calculate index coordinates in the B-Spline grid
101108
self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1

test/test_model_pet.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525
import pytest
2626

2727
from nifreeze.data.pet import PET
28-
from nifreeze.model.pet import PETModel
28+
from nifreeze.model.pet import (
29+
DEFAULT_TIMEFRAME_MIDPOINT_TOL,
30+
FIRST_MIDPOINT_VALUE_ERROR_MSG,
31+
LAST_MIDPOINT_VALUE_ERROR_MSG,
32+
TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG,
33+
PETModel,
34+
)
2935

3036

3137
@pytest.fixture
@@ -71,13 +77,38 @@ def test_petmodel_fit_predict(random_dataset):
7177

7278

7379
@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
74-
def test_petmodel_invalid_init(random_dataset):
75-
with pytest.raises(TypeError):
80+
def test_petmodel_init_mandatory_attr_errors(random_dataset):
81+
with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG):
7682
PETModel(dataset=random_dataset)
7783

84+
xlim = 55.0
85+
with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG):
86+
PETModel(dataset=random_dataset, xlim=xlim)
87+
88+
timepoints = np.array([20, 30, 40, 50, 60], dtype=np.float32)
89+
with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG):
90+
PETModel(dataset=random_dataset, timepoints=timepoints)
91+
7892

7993
@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0)
80-
def test_petmodel_time_check(random_dataset):
81-
bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32)
82-
with pytest.raises(ValueError):
83-
PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0)
94+
def test_petmodel_first_midpoint_error(random_dataset):
95+
timepoints = np.array([0, 10, 20, 30, 50], dtype=np.float32)
96+
xlim = 60.0
97+
with pytest.raises(ValueError, match=FIRST_MIDPOINT_VALUE_ERROR_MSG):
98+
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)
99+
100+
timepoints[0] = DEFAULT_TIMEFRAME_MIDPOINT_TOL
101+
with pytest.raises(ValueError, match=FIRST_MIDPOINT_VALUE_ERROR_MSG):
102+
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)
103+
104+
105+
@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 40.0)
106+
def test_petmodel_last_midpoint_error(random_dataset):
107+
xlim = 45.0
108+
timepoints = np.array([5, 10, 20, 30, 50], dtype=np.float32)
109+
with pytest.raises(ValueError, match=LAST_MIDPOINT_VALUE_ERROR_MSG):
110+
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)
111+
112+
timepoints[-1] = xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL
113+
with pytest.raises(ValueError, match=LAST_MIDPOINT_VALUE_ERROR_MSG):
114+
PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)

0 commit comments

Comments
 (0)