Skip to content

Commit 702e7f2

Browse files
committed
ENH: Validate data objects' attributes at instantiation
Validate data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities.
1 parent 997df63 commit 702e7f2

File tree

6 files changed

+349
-8
lines changed

6 files changed

+349
-8
lines changed

src/nifreeze/data/base.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,21 @@
4444

4545
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
4646

47+
DATAOBJ_ABSENCE_ERROR_MSG = "BaseDataset 'dataobj' may not be None"
"""Message used when a BaseDataset is created without a data array."""

DATAOBJ_NDIM_ERROR_MSG = "BaseDataset 'dataobj' must be a 4-D numpy array"
"""Message used when the BaseDataset data array is not a 4-D numpy array."""

AFFINE_ABSENCE_ERROR_MSG = "BaseDataset 'affine' may not be None"
"""Message used when a BaseDataset is created without an affine."""

AFFINE_SHAPE_ERROR_MSG = "BaseDataset 'affine' must be a 2-D numpy array (4 x 4)"
"""Message used when the BaseDataset affine is not a 4 x 4 numpy array."""

BRAINMASK_SHAPE_MISMATCH_ERROR_MSG = (
    "BaseDataset brainmask shape ({brainmask_shape}) "
    "does not match dataset volumes ({data_shape})."
)
"""Message template used when the brainmask and data shapes disagree."""
61+
4762

4863
def _data_repr(value: np.ndarray | None) -> str:
4964
if value is None:
@@ -58,6 +73,20 @@ def _cmp(lh: Any, rh: Any) -> bool:
5873
return lh == rh
5974

6075

76+
def _dataobj_validator(inst, attr, value) -> None:
    """Validate the ``dataobj`` field passed to the attrs initializer.

    Parameters
    ----------
    inst : object
        Instance being initialized (unused).
    attr : object
        Attribute being validated (unused).
    value : object
        Candidate data array.

    Raises
    ------
    ValueError
        If ``value`` is ``None`` or is not a 4-D numpy array.
    """
    if value is None:
        raise ValueError(DATAOBJ_ABSENCE_ERROR_MSG)

    is_4d_array = isinstance(value, np.ndarray) and value.ndim == 4
    if not is_4d_array:
        raise ValueError(DATAOBJ_NDIM_ERROR_MSG)
81+
82+
83+
def _affine_validator(inst, attr, value) -> None:
    """Validate the ``affine`` field passed to the attrs initializer.

    Parameters
    ----------
    inst : object
        Instance being initialized (unused).
    attr : object
        Attribute being validated (unused).
    value : object
        Candidate affine matrix.

    Raises
    ------
    ValueError
        If ``value`` is ``None`` or is not a numpy array of shape (4, 4).
    """
    if value is None:
        raise ValueError(AFFINE_ABSENCE_ERROR_MSG)

    is_4x4_array = isinstance(value, np.ndarray) and value.shape == (4, 4)
    if not is_4x4_array:
        raise ValueError(AFFINE_SHAPE_ERROR_MSG)
88+
89+
6190
@attrs.define(slots=True)
6291
class BaseDataset(Generic[Unpack[Ts]]):
6392
"""
@@ -75,9 +104,13 @@ class BaseDataset(Generic[Unpack[Ts]]):
75104
76105
"""
77106

78-
dataobj: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
107+
dataobj: np.ndarray = attrs.field(
108+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_dataobj_validator
109+
)
79110
"""A :obj:`~numpy.ndarray` object for the data array."""
80-
affine: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
111+
affine: np.ndarray = attrs.field(
112+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_affine_validator
113+
)
81114
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
82115
brainmask: np.ndarray | None = attrs.field(
83116
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
@@ -95,6 +128,22 @@ class BaseDataset(Generic[Unpack[Ts]]):
95128
)
96129
"""A path to an HDF5 file to store the whole dataset."""
97130

131+
def __attrs_post_init__(self) -> None:
    """Check cross-field consistency once all fields have been set.

    The brainmask is optional; when present, its shape must equal the
    first three dimensions of ``dataobj``.

    Raises
    ------
    ValueError
        If ``brainmask`` is set and its shape differs from the spatial
        shape of ``dataobj``.
    """
    mask = self.brainmask
    if mask is None:
        return

    spatial_shape = self.dataobj.shape[:3]
    if mask.shape != tuple(spatial_shape):
        message = BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
            brainmask_shape=mask.shape, data_shape=spatial_shape
        )
        raise ValueError(message)
146+
98147
def __len__(self) -> int:
99148
"""Obtain the number of volumes/frames in the dataset."""
100149
return self.dataobj.shape[-1]

src/nifreeze/data/dmri.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@
3939
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
4040
from nifreeze.utils.ndimage import get_data, load_api
4141

42+
GRADIENT_ABSENCE_ERROR_MSG = "DWI 'gradients' may not be None"
"""Message used when a DWI is created without a gradient table."""

GRADIENT_SHAPE_ERROR_MSG = "DWI 'gradients' must be a 2-D numpy array (N x 4)"
"""Message used when the DWI gradient table has an unexpected shape."""

GRADIENT_COUNT_MISMATCH_ERROR_MSG = (
    "DWI gradients count ({n_gradients}) "
    "does not match dataset volumes ({data_vols})."
)
"""Message template used when gradient and volume counts disagree."""
52+
4253
DEFAULT_CLIP_PERCENTILE = 75
4354
"""Upper percentile threshold for intensity clipping."""
4455

@@ -64,6 +75,13 @@
6475
"""Minimum number of nonzero b-values in a DWI dataset."""
6576

6677

78+
def _gradients_validator(inst, attr, value) -> None:
79+
if value is None:
80+
raise ValueError(GRADIENT_ABSENCE_ERROR_MSG)
81+
if not isinstance(value, np.ndarray) or value.shape[1] != 4:
82+
raise ValueError(GRADIENT_SHAPE_ERROR_MSG)
83+
84+
6785
@attrs.define(slots=True)
6886
class DWI(BaseDataset[np.ndarray]):
6987
"""Data representation structure for dMRI data."""
@@ -72,14 +90,43 @@ class DWI(BaseDataset[np.ndarray]):
7290
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
7391
)
7492
"""A *b=0* reference map, preferably obtained by some smart averaging."""
75-
gradients: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
93+
gradients: np.ndarray = attrs.field(
94+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_gradients_validator
95+
)
7696
"""A 2D numpy array of the gradient table (``N`` orientations x ``C`` components)."""
7797
eddy_xfms: list = attrs.field(default=None)
7898
"""List of transforms to correct for estimated eddy current distortions."""
7999

80100
def __attrs_post_init__(self) -> None:
    """Enforce basic consistency of the dMRI fields at instantiation time.

    The parent class' post-init checks are run first, then the gradient
    table is normalized, and finally the number of gradients is compared
    with the last dimension of the data (number of volumes).

    Raises
    ------
    ValueError
        If the gradient count does not match the number of volumes, or if
        a parent-class consistency check fails.
    """
    # This method overrides the parent's hook, so the parent's checks only
    # run if they are invoked explicitly here.
    super().__attrs_post_init__()

    self._normalize_gradients()

    # Only compare sizes when the data object exposes a usable shape.
    shape = getattr(getattr(self, "dataobj", None), "shape", None)
    if not isinstance(shape, (tuple, list)) or len(shape) < 1:
        return
    try:
        data_vols = int(shape[-1])
    except (TypeError, ValueError):
        return

    # NOTE(review): the field validator requires ``shape[1] == 4`` (an
    # N x 4 table), yet the count is read from axis 1 here. Confirm which
    # layout ``_normalize_gradients`` leaves in place; if it is N x 4, the
    # count should come from axis 0.
    n_gradients = self.gradients.shape[1]
    if n_gradients != data_vols:
        raise ValueError(
            GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(
                n_gradients=n_gradients, data_vols=data_vols
            )
        )
129+
83130
def _normalize_gradients(self) -> None:
84131
if self.gradients is None:
85132
return

src/nifreeze/data/pet.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,70 @@
4141
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
4242
from nifreeze.utils.ndimage import load_api
4343

44+
ARRAY_ATTRIBUTE_SHAPE_ERROR_MSG = "PET {attribute} must be a 1-D numpy array."
"""Message template used when a PET array attribute is not a 1-D numpy array."""

SCALAR_ATTRIBUTE_ERROR_MSG = "PET {attribute} must be a scalar."
"""Message template used when a PET scalar attribute is not a scalar."""

ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = (
    "PET {attribute} length ({attr_len}) "
    "does not match number of frames ({data_frames})"
)
"""Message template used when an attribute length and the frame count disagree."""
54+
55+
56+
def _1d_array_validator(inst, attr, value) -> None:
    """Validate that an attribute holds a 1-D numpy array.

    Parameters
    ----------
    inst : object
        Instance being initialized (unused).
    attr : object
        Attribute being validated; its ``name`` is used in the message.
    value : object
        Candidate value.

    Raises
    ------
    ValueError
        If ``value`` is not a numpy array with exactly one dimension.
    """
    is_1d_array = isinstance(value, np.ndarray) and value.ndim == 1
    if is_1d_array:
        return
    raise ValueError(ARRAY_ATTRIBUTE_SHAPE_ERROR_MSG.format(attribute=attr.name))
59+
60+
61+
def _scalar_validator(inst, attr, value) -> None:
    """Validate that an attribute holds a Python or numpy real scalar.

    Parameters
    ----------
    inst : object
        Instance being initialized (unused).
    attr : object
        Attribute being validated; its ``name`` is used in the message.
    value : object
        Candidate value.

    Raises
    ------
    ValueError
        If ``value`` is not an ``int``, ``float``, ``numpy.integer`` or
        ``numpy.floating`` instance.
    """
    accepted_types = (int, float, np.integer, np.floating)
    if isinstance(value, accepted_types):
        return
    raise ValueError(SCALAR_ATTRIBUTE_ERROR_MSG.format(attribute=attr.name))
64+
4465

4566
@attrs.define(slots=True)
4667
class PET(BaseDataset[np.ndarray]):
4768
"""Data representation structure for PET data."""
4869

49-
midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
70+
midframe: np.ndarray = attrs.field(
71+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_1d_array_validator
72+
)
5073
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51-
total_duration: float = attrs.field(default=None, repr=True)
74+
total_duration: float = attrs.field(default=None, repr=True, validator=_scalar_validator)
5275
"""A float representing the total duration of the dataset."""
53-
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
76+
uptake: np.ndarray = attrs.field(
77+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_1d_array_validator
78+
)
5479
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
5580

81+
def __attrs_post_init__(self) -> None:
    """Enforce basic consistency of the PET fields at instantiation time.

    The parent class' post-init checks are run first. Then the lengths of
    ``midframe`` and ``uptake`` are compared, in that order, with the last
    dimension of the data (number of frames).

    Raises
    ------
    ValueError
        If ``midframe`` or ``uptake`` does not have one entry per frame, or
        if a parent-class consistency check fails.
    """
    # This method overrides the parent's hook, so the parent's checks only
    # run if they are invoked explicitly here.
    super().__attrs_post_init__()

    data_frames = int(self.dataobj.shape[-1])

    for name in ("midframe", "uptake"):
        attr_len = len(getattr(self, name))
        if attr_len != data_frames:
            raise ValueError(
                ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
                    attribute=name,
                    attr_len=attr_len,
                    data_frames=data_frames,
                )
            )
107+
56108
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
57109
return (self.midframe[idx],)
58110

test/test_data_base.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#
2323
"""Test dataset base class."""
2424

25+
import re
2526
from pathlib import Path
2627
from tempfile import TemporaryDirectory
2728
from typing import Any
@@ -31,6 +32,13 @@
3132
import pytest
3233

3334
from nifreeze.data import NFDH5_EXT, BaseDataset, load
35+
from nifreeze.data.base import (
36+
AFFINE_ABSENCE_ERROR_MSG,
37+
AFFINE_SHAPE_ERROR_MSG,
38+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG,
39+
DATAOBJ_ABSENCE_ERROR_MSG,
40+
DATAOBJ_NDIM_ERROR_MSG,
41+
)
3442
from nifreeze.utils.ndimage import get_data
3543

3644
DEFAULT_RANDOM_DATASET_SHAPE = (32, 32, 32, 5)
@@ -51,6 +59,51 @@ def random_dataset(setup_random_uniform_spatial_data) -> BaseDataset:
5159
return BaseDataset(dataobj=data, affine=affine)
5260

5361

62+
def test_missing_dataobj_error():
    """A BaseDataset built without ``dataobj`` must raise the absence error."""
    # Escape the message so it is matched literally, as the sibling tests do.
    with pytest.raises(ValueError, match=re.escape(DATAOBJ_ABSENCE_ERROR_MSG)):
        BaseDataset()
65+
66+
67+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4, 6), 0.0, 1.0)
def test_dataobj_ndim_error(setup_random_uniform_spatial_data):
    """A data array requested with a 5-D shape must raise the dimensionality error."""
    data, _ = setup_random_uniform_spatial_data
    # Escape the message so it is matched literally, as the sibling tests do.
    with pytest.raises(ValueError, match=re.escape(DATAOBJ_NDIM_ERROR_MSG)):
        BaseDataset(dataobj=data)
72+
73+
74+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
def test_missing_affine_error(setup_random_uniform_spatial_data):
    """A BaseDataset built without ``affine`` must raise the absence error."""
    data, _ = setup_random_uniform_spatial_data
    # Escape the message so it is matched literally, as the sibling tests do.
    with pytest.raises(ValueError, match=re.escape(AFFINE_ABSENCE_ERROR_MSG)):
        BaseDataset(dataobj=data)
79+
80+
81+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
@pytest.mark.parametrize("size", ((2, 2), (3, 4), (4, 3), (5, 5)))
def test_affine_shape_error(setup_random_uniform_ndim_data, size):
    """An affine whose shape is not (4, 4) must raise the shape error."""
    # NOTE(review): the marker is named ``random_uniform_spatial_data`` while
    # the fixture requested is ``setup_random_uniform_ndim_data`` -- confirm
    # that this fixture actually reads this marker.
    bad_affine = np.ones(size)
    expected = re.escape(AFFINE_SHAPE_ERROR_MSG)
    with pytest.raises(ValueError, match=expected):
        BaseDataset(dataobj=setup_random_uniform_ndim_data, affine=bad_affine)
88+
89+
90+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
def test_brainmask_volume_mismatch_error(request, setup_random_uniform_spatial_data):
    """A brainmask larger than the data's spatial grid must raise the mismatch error."""
    data, affine = setup_random_uniform_spatial_data
    data_shape = data.shape[:3]

    # Grow every spatial dimension by one voxel so the mask cannot match.
    brainmask_size = tuple(dim + 1 for dim in data_shape)
    brainmask = request.node.rng.choice([True, False], size=brainmask_size)

    expected = BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
        brainmask_shape=brainmask.shape, data_shape=data_shape
    )
    with pytest.raises(ValueError, match=re.escape(expected)):
        BaseDataset(dataobj=data, affine=affine, brainmask=brainmask)
105+
106+
54107
def test_base_dataset_init(random_dataset: BaseDataset):
55108
"""Test that the BaseDataset can be initialized with random data."""
56109
assert random_dataset.dataobj is not None

test/test_data_dmri.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,23 @@
2222
#
2323
"""Unit tests exercising the dMRI data structure."""
2424

25+
import re
2526
from pathlib import Path
2627

2728
import nibabel as nb
2829
import numpy as np
2930
import pytest
3031

3132
from nifreeze.data import load
32-
from nifreeze.data.dmri import DWI, find_shelling_scheme, from_nii, transform_fsl_bvec
33+
from nifreeze.data.dmri import (
34+
DWI,
35+
GRADIENT_ABSENCE_ERROR_MSG,
36+
GRADIENT_COUNT_MISMATCH_ERROR_MSG,
37+
GRADIENT_SHAPE_ERROR_MSG,
38+
find_shelling_scheme,
39+
from_nii,
40+
transform_fsl_bvec,
41+
)
3342
from nifreeze.utils.ndimage import load_api
3443

3544

@@ -82,6 +91,36 @@ def test_motion_file_not_implemented():
8291
from_nii("dmri.nii.gz", motion_file="motion.x5")
8392

8493

94+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
def test_missing_gradients_error(setup_random_uniform_spatial_data):
    """A DWI built without ``gradients`` must raise the absence error."""
    data, affine = setup_random_uniform_spatial_data
    # Escape the message so it is matched literally, as the sibling tests do.
    with pytest.raises(ValueError, match=re.escape(GRADIENT_ABSENCE_ERROR_MSG)):
        DWI(dataobj=data, affine=affine)
99+
100+
101+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
def test_gradients_shape_error(setup_random_uniform_spatial_data):
    """A gradient table without four columns must raise the shape error."""
    data, affine = setup_random_uniform_spatial_data
    # Neither axis has length 4 (the data has 4 volumes, so the table is
    # 5 x 3). A 3 x 4 table would have four columns and so would not trip
    # the validator's column check.
    gradients = np.zeros((data.shape[-1] + 1, 3))
    with pytest.raises(ValueError, match=re.escape(GRADIENT_SHAPE_ERROR_MSG)):
        DWI(dataobj=data, affine=affine, gradients=gradients)
107+
108+
109+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
def test_gradients_volume_mismatch_error(setup_random_uniform_spatial_data):
    """A gradient count different from the volume count must raise the mismatch error."""
    data, affine = setup_random_uniform_spatial_data
    data_vols = data.shape[-1]
    n_gradients = data_vols + 1

    # NOTE(review): this table is 4 x (N + 1), so its second axis is not 4.
    # The gradients validator rejects tables whose second axis is not 4, so
    # it presumably raises the shape error before the count check runs --
    # confirm the intended gradient layout.
    gradients = np.zeros((4, n_gradients))

    expected = GRADIENT_COUNT_MISMATCH_ERROR_MSG.format(
        n_gradients=n_gradients, data_vols=data_vols
    )
    with pytest.raises(ValueError, match=re.escape(expected)):
        DWI(dataobj=data, affine=affine, gradients=gradients)
122+
123+
85124
@pytest.mark.parametrize("insert_b0", (False, True))
86125
@pytest.mark.parametrize("rotate_bvecs", (False, True))
87126
def test_load(datadir, tmp_path, insert_b0, rotate_bvecs): # noqa: C901

0 commit comments

Comments
 (0)