Skip to content

Commit fbee186

Browse files
committed
ENH: Validate data objects' attributes at instantiation
Validate data objects' attributes at instantiation, ensuring that the required attributes are present and have the expected dimensionality.
1 parent 997df63 commit fbee186

File tree

6 files changed

+772
-58
lines changed

6 files changed

+772
-58
lines changed

src/nifreeze/data/base.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,21 @@
4444

4545
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
4646

47+
# Messages carried by the ``ValueError`` exceptions raised while checking
# ``BaseDataset`` fields at instantiation time.
DATAOBJ_ABSENCE_ERROR_MSG = "BaseDataset 'dataobj' may not be None"
"""Error message used when ``dataobj`` is ``None``."""

DATAOBJ_NDIM_ERROR_MSG = "BaseDataset 'dataobj' must be a 4-D numpy array"
"""Error message used when ``dataobj`` is not a 4-D :obj:`~numpy.ndarray`."""

AFFINE_ABSENCE_ERROR_MSG = "BaseDataset 'affine' may not be None"
"""Error message used when ``affine`` is ``None``."""

AFFINE_SHAPE_ERROR_MSG = "BaseDataset 'affine' must be a 2-D numpy array (4 x 4)"
"""Error message used when ``affine`` is not a 4 x 4 :obj:`~numpy.ndarray`."""

BRAINMASK_SHAPE_MISMATCH_ERROR_MSG = (
    "BaseDataset brainmask shape ({brainmask_shape}) "
    "does not match dataset volumes ({data_shape})."
)
"""Error message template used when the brainmask and data shapes disagree.

Expects the ``brainmask_shape`` and ``data_shape`` format fields.
"""
61+
4762

4863
def _data_repr(value: np.ndarray | None) -> str:
4964
if value is None:
@@ -58,6 +73,20 @@ def _cmp(lh: Any, rh: Any) -> bool:
5873
return lh == rh
5974

6075

76+
def _dataobj_validator(inst, attr, value) -> None:
    """Reject a ``dataobj`` that is missing or is not a 4-D array.

    Written with the ``(instance, attribute, value)`` signature so it can be
    passed as the ``validator`` of an attrs field; ``inst`` and ``attr`` are
    not used.

    Parameters
    ----------
    inst
        The object being validated (unused).
    attr
        The attribute being validated (unused).
    value
        The candidate value for ``dataobj``.

    Raises
    ------
    :exc:`ValueError`
        If ``value`` is ``None``, is not a :obj:`~numpy.ndarray`, or does not
        have exactly four dimensions.

    """
    if value is None:
        raise ValueError(DATAOBJ_ABSENCE_ERROR_MSG)

    # The type check short-circuits so ``ndim`` is only read on real arrays.
    is_4d_array = isinstance(value, np.ndarray) and value.ndim == 4
    if not is_4d_array:
        raise ValueError(DATAOBJ_NDIM_ERROR_MSG)
81+
82+
83+
def _affine_validator(inst, attr, value) -> None:
    """Reject an ``affine`` that is missing or is not a 4 x 4 array.

    Written with the ``(instance, attribute, value)`` signature so it can be
    passed as the ``validator`` of an attrs field; ``inst`` and ``attr`` are
    not used.

    Parameters
    ----------
    inst
        The object being validated (unused).
    attr
        The attribute being validated (unused).
    value
        The candidate value for ``affine``.

    Raises
    ------
    :exc:`ValueError`
        If ``value`` is ``None``, is not a :obj:`~numpy.ndarray`, or its shape
        is not ``(4, 4)``.

    """
    if value is None:
        raise ValueError(AFFINE_ABSENCE_ERROR_MSG)

    # The type check short-circuits so ``shape`` is only read on real arrays.
    is_4x4_array = isinstance(value, np.ndarray) and value.shape == (4, 4)
    if not is_4x4_array:
        raise ValueError(AFFINE_SHAPE_ERROR_MSG)
88+
89+
6190
@attrs.define(slots=True)
6291
class BaseDataset(Generic[Unpack[Ts]]):
6392
"""
@@ -75,9 +104,13 @@ class BaseDataset(Generic[Unpack[Ts]]):
75104
76105
"""
77106

78-
dataobj: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
107+
dataobj: np.ndarray = attrs.field(
108+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_dataobj_validator
109+
)
79110
"""A :obj:`~numpy.ndarray` object for the data array."""
80-
affine: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
111+
affine: np.ndarray = attrs.field(
112+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_affine_validator
113+
)
81114
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
82115
brainmask: np.ndarray | None = attrs.field(
83116
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
@@ -95,6 +128,22 @@ class BaseDataset(Generic[Unpack[Ts]]):
95128
)
96129
"""A path to an HDF5 file to store the whole dataset."""
97130

131+
def __attrs_post_init__(self) -> None:
    """Check cross-field consistency of the dataset at instantiation time.

    When a brainmask is set, its shape must equal the first three
    dimensions of ``dataobj``. Nothing is checked when no brainmask is set.

    Raises
    ------
    :exc:`ValueError`
        If a brainmask is set and its shape differs from
        ``dataobj.shape[:3]``.

    """
    mask = self.brainmask
    if mask is None:
        return

    spatial_shape = self.dataobj.shape[:3]
    if mask.shape != tuple(spatial_shape):
        raise ValueError(
            BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
                brainmask_shape=mask.shape, data_shape=spatial_shape
            )
        )
146+
98147
def __len__(self) -> int:
99148
"""Obtain the number of volumes/frames in the dataset."""
100149
return self.dataobj.shape[-1]

0 commit comments

Comments
 (0)