Skip to content

Commit a9b3338

Browse files
committed
ENH: Validate data objects' attributes at instantiation
Validate data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities.
1 parent d27ba75 commit a9b3338

File tree

6 files changed

+351
-8
lines changed

6 files changed

+351
-8
lines changed

src/nifreeze/data/base.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,21 @@
4343

4444
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
4545

46+
DATAOBJ_ABSENCE_ERROR_MSG = "BaseDataset 'dataobj' may not be None"
47+
"""BaseDataset initialization dataobj absence error message."""
48+
49+
DATAOBJ_NDIM_ERROR_MSG = "BaseDataset 'dataobj' must be a 4-D numpy array"
50+
"""BaseDataset initialization dataobj dimensionality error message."""
51+
52+
AFFINE_ABSENCE_ERROR_MSG = "BaseDataset 'affine' may not be None"
53+
"""BaseDataset initialization affine absence error message."""
54+
55+
AFFINE_SHAPE_ERROR_MSG = "BaseDataset 'affine' must be a 2-D numpy array (4 x 4)"
56+
"""BaseDataset initialization affine shape error message."""
57+
58+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG = "BaseDataset brainmask shape ({brainmask_shape}) does not match dataset volumes ({data_shapw})."
59+
"""BaseDataset brainmask shape mismatch error message."""
60+
4661

4762
def _data_repr(value: np.ndarray | None) -> str:
4863
if value is None:
@@ -57,6 +72,20 @@ def _cmp(lh: Any, rh: Any) -> bool:
5772
return lh == rh
5873

5974

75+
def _dataobj_validator(inst, attr, value) -> None:
76+
if value is None:
77+
raise ValueError(DATAOBJ_ABSENCE_ERROR_MSG)
78+
if not isinstance(value, np.ndarray) or value.ndim != 4:
79+
raise ValueError(DATAOBJ_NDIM_ERROR_MSG)
80+
81+
82+
def _affine_validator(inst, attr, value) -> None:
83+
if value is None:
84+
raise ValueError(AFFINE_ABSENCE_ERROR_MSG)
85+
if not isinstance(value, np.ndarray) or value.shape != (4, 4):
86+
raise ValueError(AFFINE_SHAPE_ERROR_MSG)
87+
88+
6089
@attrs.define(slots=True)
6190
class BaseDataset(Generic[Unpack[Ts]]):
6291
"""
@@ -74,9 +103,13 @@ class BaseDataset(Generic[Unpack[Ts]]):
74103
75104
"""
76105

77-
dataobj: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
106+
dataobj: np.ndarray = attrs.field(
107+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_dataobj_validator
108+
)
78109
"""A :obj:`~numpy.ndarray` object for the data array."""
79-
affine: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
110+
affine: np.ndarray = attrs.field(
111+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_affine_validator
112+
)
80113
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
81114
brainmask: np.ndarray | None = attrs.field(
82115
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
@@ -94,6 +127,20 @@ class BaseDataset(Generic[Unpack[Ts]]):
94127
)
95128
"""A path to an HDF5 file to store the whole dataset."""
96129

130+
def __attrs_post_init__(self) -> None:
131+
"""Check h that rely on the fully initialized object.
132+
133+
- brainmask (if present) must match spatial shape of dataobj.
134+
"""
135+
136+
if self.brainmask is not None:
137+
if self.brainmask.shape != tuple(self.dataobj.shape[:3]):
138+
raise ValueError(
139+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
140+
brainmask_shape=self.brainmask.shape, data_shape=self.dataobj.shape[:3]
141+
)
142+
)
143+
97144
def __len__(self) -> int:
98145
"""Obtain the number of volumes/frames in the dataset."""
99146
return self.dataobj.shape[-1]

src/nifreeze/data/dmri.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@
3939
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
4040
from nifreeze.utils.ndimage import get_data, load_api
4141

42+
GRADIENT_ABSENCE_ERROR_MSG = "DWI 'gradients' may not be None"
43+
"""DWI initialization gradient absence error message."""
44+
45+
GRADIENT_SHAPE_ERROR_MSG = "DWI 'gradients' must be a 2-D numpy array (4 x N)"
46+
"""DWI initialization gradient shape error message."""
47+
48+
GRADIENT_COUNT_MISMATCH_ERROR_MSG = (
49+
"DWI gradients count ({n_gradients}) does not match dataset volumes ({data_vols})."
50+
)
51+
"""DWI initialization gradient count mismatch error message."""
52+
4253
DEFAULT_CLIP_PERCENTILE = 75
4354
"""Upper percentile threshold for intensity clipping."""
4455

@@ -64,17 +75,54 @@
6475
"""Minimum number of nonzero b-values in a DWI dataset."""
6576

6677

78+
def _gradients_validator(inst, attr, value) -> None:
79+
if value is None:
80+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
81+
if not isinstance(value, np.ndarray) or value.shape[0] != 4:
82+
raise ValueError(GRADIENT_SHAPE_ERROR_MSG)
83+
84+
6785
@attrs.define(slots=True)
6886
class DWI(BaseDataset[np.ndarray]):
6987
"""Data representation structure for dMRI data."""
7088

7189
bzero: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
7290
"""A *b=0* reference map, preferably obtained by some smart averaging."""
73-
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
91+
gradients: np.ndarray = attrs.field(
92+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_gradients_validator
93+
)
7494
"""A 2D numpy array of the gradient table (4xN)."""
7595
eddy_xfms: list = attrs.field(default=None)
7696
"""List of transforms to correct for estimated eddy current distortions."""
7797

98+
def __attrs_post_init__(self) -> None:
99+
"""Enforce presence and basic consistency of required dMRI fields at
100+
instantiation time.
101+
102+
Specifically, the number of gradient directions must match the last
103+
dimension of the data (number of volumes).
104+
"""
105+
106+
# If the data object exists and has a time/volume axis, ensure sizes
107+
# match.
108+
data_vols = None
109+
if getattr(self, "dataobj", None) is not None:
110+
shape = getattr(self.dataobj, "shape", None)
111+
if isinstance(shape, (tuple, list)) and len(shape) >= 1:
112+
try:
113+
data_vols = int(shape[-1])
114+
except (TypeError, ValueError):
115+
data_vols = None
116+
117+
if data_vols is not None:
118+
n_gradients = self.gradients.shape[1]
119+
if n_gradients != data_vols:
120+
raise ValueError(
121+
GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(
122+
n_gradients=n_gradients, data_vols=data_vols
123+
)
124+
)
125+
78126
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
79127
return (self.gradients[..., idx],)
80128

src/nifreeze/data/pet.py

Lines changed: 56 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,17 +41,70 @@
4141
from nifreeze.utils.ndimage import load_api
4242

4343

44+
ARRAY_ATTRIBUTE_SHAPE_ERROR_MSG = "PET {attribute} must be a 1-D numpy array."
45+
"""PET array attribute shape error message."""
46+
47+
SCALAR_ATTRIBUTE_ERROR_MSG = "PET {attribute} must be a scalar."
48+
"""PET scalar attribute shape error message."""
49+
50+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = (
51+
"PET {attribute} length ({attr_len}) does not match number of frames ({data_frames})"
52+
)
53+
"""PET attribute shape mismatch error message."""
54+
55+
56+
def _1d_array_validator(inst, attr, value) -> None:
57+
if not isinstance(value, np.ndarray) or value.ndim != 1:
58+
raise ValueError(ARRAY_ATTRIBUTE_SHAPE_ERROR_MSG.format(attribute=attr.name))
59+
60+
61+
def _scalar_validator(inst, attr, value) -> None:
62+
if not isinstance(value, (int, float, np.integer, np.floating)):
63+
raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name))
64+
65+
4466
@attrs.define(slots=True)
4567
class PET(BaseDataset[np.ndarray]):
4668
"""Data representation structure for PET data."""
4769

48-
midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
70+
midframe: np.ndarray = attrs.field(
71+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_1d_array_validator
72+
)
4973
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
50-
total_duration: float = attrs.field(default=None, repr=True)
74+
total_duration: float = attrs.field(default=None, repr=True, validator=_scalar_validator)
5175
"""A float representing the total duration of the dataset."""
52-
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
76+
uptake: np.ndarray = attrs.field(
77+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_1d_array_validator
78+
)
5379
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
5480

81+
def __attrs_post_init__(self) -> None:
82+
"""Enforce presence and basic consistency of required PET fields at
83+
instantiation time.
84+
85+
Specifically, the length of the midframe and uptake attributes must
86+
match the last dimension of the data (number of frames).
87+
"""
88+
data_frames = int(self.dataobj.shape[-1])
89+
90+
if len(self.midframe) != data_frames:
91+
raise ValueError(
92+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
93+
attribute=attrs.fields_dict(self.__class__).get("midframe").name,
94+
attr_len=len(self.midframe),
95+
data_frames=data_frames,
96+
)
97+
)
98+
99+
if len(self.uptake) != data_frames:
100+
raise ValueError(
101+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
102+
attribute=attrs.fields_dict(self.__class__).get("uptake").name,
103+
attr_len=len(self.uptake),
104+
data_frames=data_frames,
105+
)
106+
)
107+
55108
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
56109
return (self.midframe[idx],)
57110

test/test_data_base.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Test dataset base class."""
2424

25+
import re
2526
from pathlib import Path
2627
from tempfile import TemporaryDirectory
2728
from typing import Any
@@ -31,8 +32,17 @@
3132
import pytest
3233

3334
from nifreeze.data import NFDH5_EXT, BaseDataset, load
35+
from nifreeze.data.base import (
36+
AFFINE_ABSENCE_ERROR_MSG,
37+
AFFINE_SHAPE_ERROR_MSG,
38+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG,
39+
DATAOBJ_ABSENCE_ERROR_MSG,
40+
DATAOBJ_NDIM_ERROR_MSG,
41+
)
3442
from nifreeze.utils.ndimage import get_data
3543

44+
from test.conftest import setup_random_uniform_ndim_data
45+
3646
DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
3747
DEFAULT_RANDOM_DATASET_SIZE = int(np.prod(DEFAULT_RANDOM_DATASET_SHAPE[:3]))
3848

@@ -51,6 +61,51 @@ def random_dataset(setup_random_uniform_spatial_data) -> BaseDataset:
5161
return BaseDataset(dataobj=data, affine=affine)
5262

5363

64+
def test_missing_dataobj_error():
65+
with pytest.raises(ValueError, match=DATAOBJ_ABSENCE_ERROR_MSG):
66+
BaseDataset()
67+
68+
69+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4, 6), 0.0, 1.0)
70+
def test_dataobj_ndim_error(setup_random_uniform_spatial_data):
71+
data, _ = setup_random_uniform_spatial_data
72+
with pytest.raises(ValueError, match=DATAOBJ_NDIM_ERROR_MSG):
73+
BaseDataset(dataobj=data)
74+
75+
76+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
77+
def test_missing_affine_error(setup_random_uniform_spatial_data):
78+
data, _ = setup_random_uniform_spatial_data
79+
with pytest.raises(ValueError, match=DATAOBJ_ABSENCE_ERROR_MSG):
80+
BaseDataset(dataobj=data)
81+
82+
83+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
84+
@pytest.mark.parametrize("size", ((2, 2), (3, 4), (4, 3), (5, 5)))
85+
def test_affine_shape_error(setup_random_uniform_ndim_data, size):
86+
data = setup_random_uniform_ndim_data
87+
affine = np.ones(size)
88+
with pytest.raises(ValueError, match=re.escape(AFFINE_SHAPE_ERROR_MSG)):
89+
BaseDataset(dataobj=data, affine=affine)
90+
91+
92+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
93+
def test_brainmask_volume_mismatch_error(request, setup_random_uniform_spatial_data):
94+
data, affine = setup_random_uniform_spatial_data
95+
data_shape = data.shape[:3]
96+
brainmask_size = tuple(map(lambda x: x + 1, data_shape))
97+
brainmask = request.node.rng.choice([True, False], size=brainmask_size)
98+
with pytest.raises(
99+
ValueError,
100+
match=re.escape(
101+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
102+
brainmask_shape=brainmask.shape, data_shapw=data_shape
103+
)
104+
),
105+
):
106+
BaseDataset(dataobj=data, affine=affine, brainmask=brainmask)
107+
108+
54109
def test_base_dataset_init(random_dataset: BaseDataset):
55110
"""Test that the BaseDataset can be initialized with random data."""
56111
assert random_dataset.dataobj is not None

test/test_data_dmri.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,23 @@
2222
#
2323
"""Unit tests exercising the dMRI data structure."""
2424

25+
import re
2526
from pathlib import Path
2627

2728
import nibabel as nb
2829
import numpy as np
2930
import pytest
3031

3132
from nifreeze.data import load
32-
from nifreeze.data.dmri import DWI, find_shelling_scheme, from_nii, transform_fsl_bvec
33+
from nifreeze.data.dmri import (
34+
GRADIENT_COUNT_MISMATCH_ERROR_MSG,
35+
GRADIENT_ABSENCE_ERROR_MSG,
36+
GRADIENT_SHAPE_ERROR_MSG,
37+
DWI,
38+
find_shelling_scheme,
39+
from_nii,
40+
transform_fsl_bvec,
41+
)
3342
from nifreeze.utils.ndimage import load_api
3443

3544

@@ -77,6 +86,36 @@ def test_main(datadir):
7786
assert isinstance(load(input_file), DWI)
7887

7988

89+
@pytest.mark.random_uniform_spatial_data((2, 2, 2), 0.0, 1.0)
90+
def test_missing_gradients_error(setup_random_uniform_spatial_data):
91+
data, affine = setup_random_uniform_spatial_data
92+
with pytest.raises(ValueError, match=GRADIENT_ABSENCE_ERROR_MSG):
93+
DWI(dataobj=data, affine=affine)
94+
95+
96+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
97+
def test_gradients_shape_error(setup_random_uniform_spatial_data):
98+
data, affine = setup_random_uniform_spatial_data
99+
gradients = np.zeros((3, data.shape[-1]))
100+
with pytest.raises(ValueError, match=re.escape(GRADIENT_SHAPE_ERROR_MSG)):
101+
DWI(dataobj=data, affine=affine, gradients=gradients)
102+
103+
104+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
105+
def test_gradients_volume_mismatch_error(setup_random_uniform_spatial_data):
106+
data, affine = setup_random_uniform_spatial_data
107+
data_vols = data.shape[-1]
108+
n_gradients = data_vols + 1
109+
gradients = np.zeros((4, n_gradients))
110+
with pytest.raises(
111+
ValueError,
112+
match=re.escape(
113+
GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(n_gradients=n_gradients, data_vols=data_vols)
114+
),
115+
):
116+
DWI(dataobj=data, affine=affine, gradients=gradients)
117+
118+
80119
@pytest.mark.parametrize("insert_b0", (False, True))
81120
@pytest.mark.parametrize("rotate_bvecs", (False, True))
82121
def test_load(datadir, tmp_path, insert_b0, rotate_bvecs): # noqa: C901

0 commit comments

Comments
 (0)