diff --git a/src/nifreeze/model/pet.py b/src/nifreeze/model/pet.py index 9428bdb91..23fdeb587 100644 --- a/src/nifreeze/model/pet.py +++ b/src/nifreeze/model/pet.py @@ -38,6 +38,13 @@ DEFAULT_TIMEFRAME_MIDPOINT_TOL = 1e-2 """Time frame tolerance in seconds.""" +TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG = "timepoints must be provided in initialization." +"""PET model timepoint data missing error message.""" +FIRST_MIDPOINT_VALUE_ERROR_MSG = "First frame midpoint should not be zero or negative." +"""PET model first midpoint value error message.""" +LAST_MIDPOINT_VALUE_ERROR_MSG = "Last frame midpoint should not be equal or greater than duration." +"""PET model last midpoint value error message.""" + class PETModel(BaseModel): """A PET imaging realignment model based on B-Spline approximation.""" @@ -84,7 +91,7 @@ def __init__( super().__init__(dataset, **kwargs) if timepoints is None or xlim is None: - raise TypeError("timepoints must be provided in initialization") + raise TypeError(TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG) self._order = order self._x = np.array(timepoints, dtype="float32") @@ -93,9 +100,9 @@ def __init__( self._thresh_pct = thresh_pct if self._x[0] < DEFAULT_TIMEFRAME_MIDPOINT_TOL: - raise ValueError("First frame midpoint should not be zero or negative") + raise ValueError(FIRST_MIDPOINT_VALUE_ERROR_MSG) if self._x[-1] > (self._xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL): - raise ValueError("Last frame midpoint should not be equal or greater than duration") + raise ValueError(LAST_MIDPOINT_VALUE_ERROR_MSG) # Calculate index coordinates in the B-Spline grid self._n_ctrl = n_ctrl or (len(timepoints) // 4) + 1 diff --git a/test/test_model_pet.py b/test/test_model_pet.py index 6081a07d5..2c373bed2 100644 --- a/test/test_model_pet.py +++ b/test/test_model_pet.py @@ -25,7 +25,13 @@ import pytest from nifreeze.data.pet import PET -from nifreeze.model.pet import PETModel +from nifreeze.model.pet import ( + DEFAULT_TIMEFRAME_MIDPOINT_TOL, + FIRST_MIDPOINT_VALUE_ERROR_MSG, + LAST_MIDPOINT_VALUE_ERROR_MSG, + TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG, + PETModel, +) @pytest.fixture @@ -71,13 +77,38 @@ def test_petmodel_fit_predict(random_dataset): @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_invalid_init(random_dataset): - with pytest.raises(TypeError): +def test_petmodel_init_mandatory_attr_errors(random_dataset): + with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG): PETModel(dataset=random_dataset) + xlim = 55.0 + with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG): + PETModel(dataset=random_dataset, xlim=xlim) + + timepoints = np.array([20, 30, 40, 50, 60], dtype=np.float32) + with pytest.raises(TypeError, match=TIMEPOINT_XLIM_DATA_MISSING_ERROR_MSG): + PETModel(dataset=random_dataset, timepoints=timepoints) + @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) -def test_petmodel_time_check(random_dataset): - bad_times = np.array([0, 10, 20, 30, 50], dtype=np.float32) - with pytest.raises(ValueError): - PETModel(dataset=random_dataset, timepoints=bad_times, xlim=60.0) +def test_petmodel_first_midpoint_error(random_dataset): + timepoints = np.array([0, 10, 20, 30, 50], dtype=np.float32) + xlim = 60.0 + with pytest.raises(ValueError, match=FIRST_MIDPOINT_VALUE_ERROR_MSG): + PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim) + + timepoints[0] = DEFAULT_TIMEFRAME_MIDPOINT_TOL + with pytest.raises(ValueError, match=FIRST_MIDPOINT_VALUE_ERROR_MSG): + PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim) + + +@pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 40.0) +def test_petmodel_last_midpoint_error(random_dataset): + xlim = 45.0 + timepoints = np.array([5, 10, 20, 30, 50], dtype=np.float32) + with pytest.raises(ValueError, match=LAST_MIDPOINT_VALUE_ERROR_MSG): + PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim) + + timepoints[-1] = xlim - DEFAULT_TIMEFRAME_MIDPOINT_TOL + with pytest.raises(ValueError, match=LAST_MIDPOINT_VALUE_ERROR_MSG): + PETModel(dataset=random_dataset, timepoints=timepoints, xlim=xlim)