@@ -112,6 +112,57 @@ def test_trivial_model(request, use_mask):
112112 assert np .all (_clipped_S0 == predicted )
113113
114114
115+ def test_expectation_model (request ):
116+ class DummySequenceDataset :
117+ def __init__ (self , data , brainmask ):
118+ # data_4d shape is (x,y,z,t)
119+ self .data = data
120+ self .brainmask = brainmask
121+
122+ def __len__ (self ):
123+ # pretend T timepoints
124+ return self .data .shape [- 1 ]
125+
126+ def __getitem__ (self , index ):
127+ # When index is boolean mask, emulate the original dataset behavior:
128+ # return a tuple whose first element is the 4D data subset
129+ if isinstance (index , (list , tuple , np .ndarray )):
130+ # Boolean indexing along time axis
131+ sel = np .asarray (index , dtype = bool )
132+ # Create subset along last axis and return as first element in tuple
133+ return (self .data [..., sel ],)
134+ # Other cases: forward slice/index to the timepoint
135+ return (self .data [..., index ],)
136+
137+ # Create a dataset with a single voxel and 4 timepoints
138+ vals = np .array ([1.0 , 2.0 , 3.0 , 4.0 ], dtype = float )
139+ _data = vals .reshape ((1 , 1 , 1 , - 1 ))
140+ _brainmask = request .node .rng .choice ([True , False ], size = _data .shape [:3 ])
141+ dataset = DummySequenceDataset (_data , _brainmask )
142+
143+ stat = "mean"
144+ avg_func = getattr (np , stat )
145+ em_model = model .ExpectationModel (dataset , stat = stat )
146+
147+ # Calling with index specified should exclude that index and return the
148+ # immediate value
149+ # exclude index 1 => use timepoints 0,2,3 -> mean of [1,3,4] = 8/3
150+ _index = 1
151+ index_mask = np .ones (len (dataset ), dtype = bool )
152+ index_mask [_index ] = False
153+ pred = em_model .fit_predict (index = 1 )
154+ assert np .allclose (pred , avg_func (dataset [index_mask ][0 ], axis = - 1 ))
155+
156+ # First call with index=None should compute and lock the fit
157+ pred = em_model .fit_predict (index = None )
158+ assert em_model ._locked_fit is not None
159+ assert np .allclose (pred , em_model ._locked_fit )
160+ assert np .allclose (pred , avg_func (_data , axis = - 1 ))
161+ # Calling again returns the locked fit
162+ pred2 = em_model .fit_predict (index = None )
163+ assert pred2 is pred
164+
165+
115166def test_average_model ():
116167 """Check the implementation of the average DW model."""
117168
0 commit comments