Skip to content

Commit d27ba75

Browse files
authored
Merge pull request #300 from jhlegarreta/tst/fix-trivial-model-test-typing
STY: Fix `Trivial` model test typing
2 parents adf71da + b17040a commit d27ba75

File tree

2 files changed

+20
-7
lines changed

2 files changed

+20
-7
lines changed

src/nifreeze/model/base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@
2727

2828
import numpy as np
2929

30-
mask_absence_warn_msg = (
30+
MASK_ABSENCE_WARN_MSG = (
3131
"No mask provided; consider using a mask to avoid issues in model optimization."
3232
)
33+
"""Mask warning message."""
34+
PREDICTED_MAP_ERROR_MSG = "This model requires the predicted map at initialization"
35+
"""Oracle requirement error message."""
3336
UNSUPPORTED_MODEL_ERROR_MSG = "Unsupported model <{model}>."
3437
"""Unsupported model error message"""
3538

@@ -102,7 +105,7 @@ def __init__(self, dataset, **kwargs):
102105
self._dataset = dataset
103106
# Warn if mask not present
104107
if dataset.brainmask is None:
105-
warn(mask_absence_warn_msg, stacklevel=2)
108+
warn(MASK_ABSENCE_WARN_MSG, stacklevel=2)
106109

107110
@abstractmethod
108111
def fit_predict(self, index: int | None = None, **kwargs) -> np.ndarray | None:
@@ -139,7 +142,7 @@ def __init__(self, dataset, predicted=None, **kwargs):
139142
)
140143

141144
if self._locked_fit is None:
142-
raise TypeError("This model requires the predicted map at initialization")
145+
raise TypeError(PREDICTED_MAP_ERROR_MSG)
143146

144147
def fit_predict(self, *_, **kwargs) -> np.ndarray | None:
145148
"""Return the reference map."""

test/test_model.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939
DWI,
4040
)
4141
from nifreeze.model._dipy import GaussianProcessModel
42-
from nifreeze.model.base import UNSUPPORTED_MODEL_ERROR_MSG, mask_absence_warn_msg
42+
from nifreeze.model.base import (
43+
MASK_ABSENCE_WARN_MSG,
44+
PREDICTED_MAP_ERROR_MSG,
45+
UNSUPPORTED_MODEL_ERROR_MSG,
46+
)
4347
from nifreeze.testing import simulations as _sim
4448

4549

@@ -60,6 +64,12 @@ class DummyDataset:
6064
pass
6165

6266

67+
class DummyDatasetNoRef:
68+
def __init__(self):
69+
# No reference or bzero here to trigger TrivialModel error
70+
self.brainmask = np.ones((1, 1, 1, 1)).astype(bool)
71+
72+
6373
def test_base_model():
6474
from nifreeze.model.base import BaseModel
6575

@@ -78,8 +88,8 @@ def test_trivial_model(request, use_mask):
7888
rng = request.node.rng
7989

8090
# Should not allow initialization without an oracle
81-
with pytest.raises(TypeError):
82-
model.TrivialModel() # type: ignore[call-arg]
91+
with pytest.raises(TypeError, match=PREDICTED_MAP_ERROR_MSG):
92+
model.TrivialModel(DummyDatasetNoRef())
8393

8494
size = (2, 2, 2)
8595
mask = None
@@ -88,7 +98,7 @@ def test_trivial_model(request, use_mask):
8898
mask = np.ones(size, dtype=bool)
8999
context = contextlib.nullcontext()
90100
else:
91-
context = pytest.warns(UserWarning, match=mask_absence_warn_msg)
101+
context = pytest.warns(UserWarning, match=MASK_ABSENCE_WARN_MSG)
92102

93103
_S0 = rng.normal(size=size)
94104

0 commit comments

Comments
 (0)