Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 211 additions & 2 deletions src/nifreeze/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,125 @@

ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))

DATAOBJ_ABSENCE_ERROR_MSG = "BaseDataset 'dataobj' may not be None"
"""BaseDataset initialization dataobj absence error message."""

DATAOBJ_OBJECT_ERROR_MSG = "BaseDataset 'dataobj' must be a numpy array."
"""BaseDataset initialization dataobj object error message."""

DATAOBJ_NDIM_ERROR_MSG = "BaseDataset 'dataobj' must be a 4-D numpy array"
"""BaseDataset initialization dataobj dimensionality error message."""

AFFINE_ABSENCE_ERROR_MSG = "BaseDataset 'affine' may not be None"
"""BaseDataset initialization affine absence error message."""

AFFINE_OBJECT_ERROR_MSG = "BaseDataset 'affine' must be a numpy array."
"""BaseDataset initialization affine object error message."""

AFFINE_NDIM_ERROR_MSG = "BaseDataset 'affine' must be a 2D array"
"""Affine dimensionality error message."""

AFFINE_SHAPE_ERROR_MSG = "BaseDataset 'affine' must be a 2-D numpy array (4 x 4)"
"""BaseDataset initialization affine shape error message."""

BRAINMASK_SHAPE_MISMATCH_ERROR_MSG = "BaseDataset 'brainmask' shape ({brainmask_shape}) does not match dataset volumes ({data_shape})."
"""BaseDataset brainmask shape mismatch error message."""


def _has_dim_size(value: Any, size: int) -> bool:
"""Return ``True`` if ``value`` has a ``.shape`` attribute and one of its
dimensions equals ``size``.

This is useful for checks where at least one axis must match an expected
length. It does not require a specific axis index; it only verifies presence
of the size in any axis in ``.shape``.

Parameters
----------
value : :obj:`Any`
Object to inspect. Typical inputs are NumPy arrays or objects exposing
``.shape``.
size : :obj:`int`
The required dimension size to look for in ``value.shape``.

Returns
-------
:obj:`bool`
``True`` if ``.shape`` exists and any of its integers equals ``size``,
``False`` otherwise.

Examples
--------
>>> _has_dim_size(np.zeros((10, 3)), 3)
True
>>> _has_dim_size(np.zeros((4, 5)), 6)
False
"""

shape = getattr(value, "shape", None)
if shape is None:
return False
# Shape may be an object that is not iterable; handle TypeError explicitly
try:
return size in tuple(shape)
except TypeError:
return False


def _has_ndim(value: Any, ndim: int) -> bool:
"""Check if ``value`` has ``ndim`` dimensionality.

Returns ``True`` if `value` has an ``.ndim`` attribute equal to ``ndim``, or
if it has a ``.shape`` attribute whose length equals ``ndim``.

This helper is tolerant: it accepts objects that either:
- expose an integer ``.ndim`` attribute (e.g., NumPy arrays), or
- expose a ``.shape`` attribute (sequence/tuple-like) whose length equals
``ndim``.

Parameters
----------
value : :obj:`Any`
Object to inspect for dimensionality. Typical inputs are NumPy arrays,
array-likes, or objects that provide ``.ndim`` / ``.shape``.
ndim : :obj:`int`
The required dimensionality.

Returns
-------
:obj:`bool`
``True`` if ``value`` appears to have ``ndim`` dimensions, ``False``
otherwise.

Examples
--------
>>> _has_ndim(np.zeros((2, 3)), 2)
True
>>> _has_ndim(np.zeros((3,)), 2)
False
>>> class WithShape:
... shape = (2, 2, 2)
>>> _has_ndim(WithShape(), 3)
True
"""

# Prefer .ndim if available
ndim_attr = getattr(value, "ndim", None)
if ndim_attr is not None:
try:
return int(ndim_attr) == ndim
except (TypeError, ValueError):
return False

# Fallback to checking shape length
shape = getattr(value, "shape", None)
if shape is None:
return False
try:
return len(tuple(shape)) == ndim
except TypeError:
return False


def _data_repr(value: np.ndarray | None) -> str:
if value is None:
Expand All @@ -58,6 +177,76 @@ def _cmp(lh: Any, rh: Any) -> bool:
return lh == rh


def _dataobj_validator(inst: BaseDataset, attr: attrs.Attribute, value: Any) -> None:
"""Strict validator for data objects.

It enforces that ``value`` is present and is a NumPy array with exactly 4
dimensions (``ndim == 4``).

This function is intended for use as an attrs-style validator.

Parameters
----------
inst : :obj:`~nifreeze.data.base.BaseDataset`
The instance being validated (unused, present for validator signature).
attr : :obj:`attrs.Attribute`
The attribute being validated (unused, present for validator signature).
value : :obj:`Any`
The value to validate.

Raises
------
exc:`TypeError`
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
exc:`ValueError`
If the value is ``None``, or not 4-dimensional.
"""
if value is None:
raise ValueError(DATAOBJ_ABSENCE_ERROR_MSG)

if not isinstance(value, np.ndarray):
raise TypeError(DATAOBJ_OBJECT_ERROR_MSG)

if not _has_ndim(value, 4):
raise ValueError(DATAOBJ_NDIM_ERROR_MSG)


def _affine_validator(inst: BaseDataset, attr: attrs.Attribute, value: Any) -> None:
"""Strict validator for affine matrices.

It enforces that ``value`` is present and is a 4x4 NumPy array.

This function is intended for use as an attrs-style validator.

Parameters
----------
inst : :obj:`~nifreeze.data.base.BaseDataset`
The instance being validated (unused, present for validator signature).
attr : :obj:`attrs.Attribute`
The attribute being validated (unused, present for validator signature).
value : :obj:`Any`
The value to validate.

Raises
------
exc:`TypeError`
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
exc:`ValueError`
If the value is ``None``, or not shaped ``(4, 4)``.
"""
if value is None:
raise ValueError(AFFINE_ABSENCE_ERROR_MSG)

if not isinstance(value, np.ndarray):
raise TypeError(AFFINE_OBJECT_ERROR_MSG)

if not _has_ndim(value, 2):
raise ValueError(AFFINE_NDIM_ERROR_MSG)

if value.shape != (4, 4):
raise ValueError(AFFINE_SHAPE_ERROR_MSG)


@attrs.define(slots=True)
class BaseDataset(Generic[Unpack[Ts]]):
"""
Expand All @@ -75,9 +264,13 @@ class BaseDataset(Generic[Unpack[Ts]]):

"""

dataobj: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
dataobj: np.ndarray = attrs.field(
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_dataobj_validator
)
"""A :obj:`~numpy.ndarray` object for the data array."""
affine: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
affine: np.ndarray = attrs.field(
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_affine_validator
)
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
brainmask: np.ndarray | None = attrs.field(
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
Expand All @@ -95,6 +288,22 @@ class BaseDataset(Generic[Unpack[Ts]]):
)
"""A path to an HDF5 file to store the whole dataset."""

def __attrs_post_init__(self) -> None:
"""Enforce basic consistency of base dataset fields at instantiation
time.

Specifically, the brainmask (if present) must match spatial shape of
dataobj.
"""

if self.brainmask is not None:
if self.brainmask.shape != tuple(self.dataobj.shape[:3]):
raise ValueError(
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
brainmask_shape=self.brainmask.shape, data_shape=self.dataobj.shape[:3]
)
)

def __len__(self) -> int:
"""Obtain the number of volumes/frames in the dataset."""
return self.dataobj.shape[-1]
Expand Down
117 changes: 117 additions & 0 deletions test/test_data_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#
"""Test dataset base class."""

import re
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any
Expand All @@ -31,6 +32,18 @@
import pytest

from nifreeze.data import NFDH5_EXT, BaseDataset, load
from nifreeze.data.base import (
AFFINE_ABSENCE_ERROR_MSG,
AFFINE_NDIM_ERROR_MSG,
AFFINE_OBJECT_ERROR_MSG,
AFFINE_SHAPE_ERROR_MSG,
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG,
DATAOBJ_ABSENCE_ERROR_MSG,
DATAOBJ_NDIM_ERROR_MSG,
DATAOBJ_OBJECT_ERROR_MSG,
_has_dim_size,
_has_ndim,
)
from nifreeze.utils.ndimage import get_data

DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
Expand All @@ -51,6 +64,110 @@ def random_dataset(setup_random_uniform_spatial_data) -> BaseDataset:
return BaseDataset(dataobj=data, affine=affine)


@pytest.mark.parametrize(
"value, size, expected",
[
(np.zeros((2, 4, 5)), 4, True),
(np.zeros((2, 4, 5)), 6, False),
# Objects without .shape
([1, 2, 3], 3, False),
# Shape that is not iterable
(
type("BadShape", (), {"shape": 5})(),
5,
False,
),
],
)
def test_has_dim_size(value, size, expected):
assert _has_dim_size(value, size) is expected


@pytest.mark.parametrize(
"obj_factory, ndim, expected",
[
(lambda: type("WithNdim", (), {"ndim": 2})(), 2, True),
(lambda: type("WithNdim", (), {"ndim": 2})(), 3, False),
(lambda: type("BadNdim", (), {"ndim": "not-an-int"})(), 2, False),
(lambda: type("WithShape", (), {"shape": (3, 4)})(), 2, True),
(lambda: (123), 1, False), # No ndim or shape
],
)
def test_has_ndim(obj_factory, ndim, expected):
obj = obj_factory()
assert _has_ndim(obj, ndim) is expected


@pytest.mark.parametrize(
"value, expected_exc, expected_msg",
[
(None, ValueError, DATAOBJ_ABSENCE_ERROR_MSG),
(1, TypeError, DATAOBJ_OBJECT_ERROR_MSG),
],
)
def test_dataobj_basic_errors(value, expected_exc, expected_msg):
with pytest.raises(expected_exc, match=str(expected_msg)):
BaseDataset(dataobj=value) # type: ignore[arg-type]


@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4, 6), 0.0, 1.0)
def test_dataobj_ndim_error(setup_random_uniform_spatial_data):
data, _ = setup_random_uniform_spatial_data
with pytest.raises(ValueError, match=DATAOBJ_NDIM_ERROR_MSG):
BaseDataset(dataobj=data)


@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
@pytest.mark.parametrize(
"affine, expected_exc, expected_msg",
[
(None, ValueError, AFFINE_ABSENCE_ERROR_MSG),
(1, TypeError, AFFINE_OBJECT_ERROR_MSG),
],
)
def test_missing_affine_error(
setup_random_uniform_spatial_data, affine, expected_exc, expected_msg
):
data, _ = setup_random_uniform_spatial_data
with pytest.raises(expected_exc, match=str(expected_msg)):
BaseDataset(dataobj=data, affine=affine) # type: ignore[arg-type]


@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
@pytest.mark.parametrize("size", ((2,), (3, 4, 2)))
def test_affine_ndim_error(setup_random_uniform_ndim_data, size):
data = setup_random_uniform_ndim_data
affine = np.ones(size)
with pytest.raises(ValueError, match=re.escape(AFFINE_NDIM_ERROR_MSG)):
BaseDataset(dataobj=data, affine=affine)


@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
@pytest.mark.parametrize("size", ((2, 2), (3, 4), (4, 3), (5, 5)))
def test_affine_shape_error(setup_random_uniform_ndim_data, size):
data = setup_random_uniform_ndim_data
affine = np.ones(size)
with pytest.raises(ValueError, match=re.escape(AFFINE_SHAPE_ERROR_MSG)):
BaseDataset(dataobj=data, affine=affine)


@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
def test_brainmask_volume_mismatch_error(request, setup_random_uniform_spatial_data):
data, affine = setup_random_uniform_spatial_data
data_shape = data.shape[:3]
brainmask_size = tuple(map(lambda x: x + 1, data_shape))
brainmask = request.node.rng.choice([True, False], size=brainmask_size)
with pytest.raises(
ValueError,
match=re.escape(
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
brainmask_shape=brainmask.shape, data_shape=data_shape
)
),
):
BaseDataset(dataobj=data, affine=affine, brainmask=brainmask)


def test_base_dataset_init(random_dataset: BaseDataset):
"""Test that the BaseDataset can be initialized with random data."""
assert random_dataset.dataobj is not None
Expand Down
3 changes: 2 additions & 1 deletion test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def test_trivial_model(request, use_mask):

data = DWI(
dataobj=rng.normal(size=(*_S0.shape, 10)),
affine=np.eye(4),
bzero=_clipped_S0,
brainmask=mask,
)
Expand Down Expand Up @@ -197,7 +198,7 @@ def test_average_model():
mask = np.ones(size[:3], dtype=bool)

data *= gtab[:, -1]
dataset = DWI(dataobj=data, gradients=gtab, brainmask=mask)
dataset = DWI(dataobj=data, affine=np.eye(4), gradients=gtab, brainmask=mask)

avgmodel_mean = model.AverageDWIModel(dataset, stat="mean")
avgmodel_mean_full = model.AverageDWIModel(dataset, stat="mean", atol_low=2000, atol_high=2000)
Expand Down