From 62f866cb1dc405ddef0a29627dd1b3c267f28a29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Sun, 17 Aug 2025 22:18:16 -0400 Subject: [PATCH] REF: Inherit base model class from `ABC` to enforce native abstraction Inherit base model class from `ABC` to enforce native abstraction. Fixes: ``` Abstract methods are allowed in classes whose metaclass is 'ABCMeta' ``` raised locally by the IDE. --- src/nifreeze/model/base.py | 8 +++++--- test/test_model.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/nifreeze/model/base.py b/src/nifreeze/model/base.py index d4f944bc..540ee812 100644 --- a/src/nifreeze/model/base.py +++ b/src/nifreeze/model/base.py @@ -22,7 +22,7 @@ # """Base infrastructure for nifreeze's models.""" -from abc import abstractmethod +from abc import ABC, ABCMeta, abstractmethod from typing import Union from warnings import warn @@ -77,7 +77,7 @@ def init(model: str | None = None, **kwargs): raise NotImplementedError(f"Unsupported model <{model}>.") -class BaseModel: +class BaseModel(ABC): """ Defines the interface and default methods. @@ -88,6 +88,8 @@ class BaseModel: """ + __metaclass__ = ABCMeta + __slots__ = ("_dataset", "_locked_fit") def __init__(self, dataset, **kwargs): @@ -116,7 +118,7 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N If ``None``, no prediction will be executed. """ - raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.") + return None class TrivialModel(BaseModel): diff --git a/test/test_model.py b/test/test_model.py index a4545860..b4c0924f 100644 --- a/test/test_model.py +++ b/test/test_model.py @@ -35,6 +35,16 @@ from nifreeze.testing import simulations as _sim +def test_base_model(): + from nifreeze.model.base import BaseModel + + with pytest.raises( + TypeError, + match="Can't instantiate abstract class BaseModel without an implementation for abstract method 'fit_predict'", + ): + BaseModel(None) + + @pytest.mark.parametrize("use_mask", (False, True)) def test_trivial_model(request, use_mask): """Check the implementation of the trivial B0 model."""