Skip to content

Commit 77dc381

Browse files
authored
Merge pull request #298 from jhlegarreta/tst/test-expectation-model
TST: Test the `Expectation` model
2 parents 378df62 + 5306ed1 commit 77dc381

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

test/test_model.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,57 @@ def test_trivial_model(request, use_mask):
112112
assert np.all(_clipped_S0 == predicted)
113113

114114

115+
def test_expectation_model(request):
116+
class DummySequenceDataset:
117+
def __init__(self, data, brainmask):
118+
# data_4d shape is (x,y,z,t)
119+
self.data = data
120+
self.brainmask = brainmask
121+
122+
def __len__(self):
123+
# pretend T timepoints
124+
return self.data.shape[-1]
125+
126+
def __getitem__(self, index):
127+
# When index is boolean mask, emulate the original dataset behavior:
128+
# return a tuple whose first element is the 4D data subset
129+
if isinstance(index, (list, tuple, np.ndarray)):
130+
# Boolean indexing along time axis
131+
sel = np.asarray(index, dtype=bool)
132+
# Create subset along last axis and return as first element in tuple
133+
return (self.data[..., sel],)
134+
# Other cases: forward slice/index to the timepoint
135+
return (self.data[..., index],)
136+
137+
# Create a dataset with a single voxel and 4 timepoints
138+
vals = np.array([1.0, 2.0, 3.0, 4.0], dtype=float)
139+
_data = vals.reshape((1, 1, 1, -1))
140+
_brainmask = request.node.rng.choice([True, False], size=_data.shape[:3])
141+
dataset = DummySequenceDataset(_data, _brainmask)
142+
143+
stat = "mean"
144+
avg_func = getattr(np, stat)
145+
em_model = model.ExpectationModel(dataset, stat=stat)
146+
147+
# Calling with index specified should exclude that index and return the
148+
# immediate value
149+
# exclude index 1 => use timepoints 0,2,3 -> mean of [1,3,4] = 8/3
150+
_index = 1
151+
index_mask = np.ones(len(dataset), dtype=bool)
152+
index_mask[_index] = False
153+
pred = em_model.fit_predict(index=1)
154+
assert np.allclose(pred, avg_func(dataset[index_mask][0], axis=-1))
155+
156+
# First call with index=None should compute and lock the fit
157+
pred = em_model.fit_predict(index=None)
158+
assert em_model._locked_fit is not None
159+
assert np.allclose(pred, em_model._locked_fit)
160+
assert np.allclose(pred, avg_func(_data, axis=-1))
161+
# Calling again returns the locked fit
162+
pred2 = em_model.fit_predict(index=None)
163+
assert pred2 is pred
164+
165+
115166
def test_average_model():
116167
"""Check the implementation of the average DW model."""
117168

0 commit comments

Comments
 (0)