diff --git a/test/test_model.py b/test/test_model.py index deea21767..53791baef 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -304,6 +304,7 @@ def test_model_factory_invalid_model(): @pytest.mark.parametrize( "name, expected_cls", [ + ("trivial", model.TrivialModel), ("avg", model.ExpectationModel), ("average", model.ExpectationModel), ("mean", model.ExpectationModel), @@ -318,6 +319,14 @@ def test_factory_variants(name, expected_cls, setup_random_base_data): motion_affines=motion_affines, datahdr=datahdr, ) + if name == "trivial": + # ToDo + # Does not work because BaseDataset does not have a reference property + # ToDo + # Test the bzero case + # ToDo + # Test the predicted case + setattr(dataset, "reference", np.zeros(dataobj.shape[:3])) model_instance = model.ModelFactory.init(name, dataset=dataset) assert isinstance(model_instance, expected_cls)