diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index d4f944bc..540ee812 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -22,7 +22,7 @@ # """Base infrastructure for nifreeze's models.""" -from abc import abstractmethod +from abc import ABC, ABCMeta, abstractmethod from typing import Union from warnings import warn @@ -77,7 +77,7 @@ def init(model: str | None = None, **kwargs): raise NotImplementedError(f"Unsupported model <{model}>.") -class BaseModel: +class BaseModel(ABC): """ Defines the interface and default methods. @@ -88,6 +88,8 @@ class BaseModel: """ + __metaclass__ = ABCMeta + __slots__ = ("_dataset", "_locked_fit") def __init__(self, dataset, **kwargs): @@ -116,7 +118,7 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N If ``None``, no prediction will be executed. """ - raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.") + return None class TrivialModel(BaseModel): diff --git a/test/test_model.py b/test/test_model.py index a4545860..b4c0924f 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -35,6 +35,16 @@ from nifreeze.testing import simulations as _sim +def test_base_model(): + from nifreeze.model.base import BaseModel + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class BaseModel without an implementation for abstract method 'fit_predict'", + ): + BaseModel(None) + + @pytest.mark.parametrize("use_mask", (False, True)) def test_trivial_model(request, use_mask): """Check the implementation of the trivial B0 model."""