Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/nifreeze/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
#
"""Base infrastructure for nifreeze's models."""

from abc import abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from typing import Union
from warnings import warn

Expand Down Expand Up @@ -77,7 +77,7 @@ def init(model: str | None = None, **kwargs):
raise NotImplementedError(f"Unsupported model <{model}>.")


class BaseModel:
class BaseModel(ABC):
"""
Defines the interface and default methods.
Expand All @@ -88,6 +88,8 @@ class BaseModel:
"""

__metaclass__ = ABCMeta

__slots__ = ("_dataset", "_locked_fit")

def __init__(self, dataset, **kwargs):
Expand Down Expand Up @@ -116,7 +118,7 @@ def fit_predict(self, index: int | None = None, **kwargs) -> Union[np.ndarray, N
If ``None``, no prediction will be executed.
"""
raise NotImplementedError("Cannot call fit_predict() on a BaseModel instance.")
return None


class TrivialModel(BaseModel):
Expand Down
10 changes: 10 additions & 0 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,16 @@
from nifreeze.testing import simulations as _sim


def test_base_model():
from nifreeze.model.base import BaseModel

with pytest.raises(
TypeError,
match="Can't instantiate abstract class BaseModel without an implementation for abstract method 'fit_predict'",
):
BaseModel(None)


@pytest.mark.parametrize("use_mask", (False, True))
def test_trivial_model(request, use_mask):
"""Check the implementation of the trivial B0 model."""
Expand Down