Skip to content

Commit 22528df

Browse files
committed
ENH: Validate data objects' attributes at instantiation
Validate data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities.
1 parent 6723229 commit 22528df

File tree

6 files changed

+925
-60
lines changed

6 files changed

+925
-60
lines changed

src/nifreeze/data/base.py

Lines changed: 211 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,125 @@
4444

4545
ImageGrid = namedtuple("ImageGrid", ("shape", "affine"))
4646

47+
DATAOBJ_ABSENCE_ERROR_MSG = "BaseDataset 'dataobj' may not be None"
48+
"""BaseDataset initialization dataobj absence error message."""
49+
50+
DATAOBJ_OBJECT_ERROR_MSG = "BaseDataset 'dataobj' must be a numpy array."
51+
"""BaseDataset initialization dataobj object error message."""
52+
53+
DATAOBJ_NDIM_ERROR_MSG = "BaseDataset 'dataobj' must be a 4-D numpy array"
54+
"""BaseDataset initialization dataobj dimensionality error message."""
55+
56+
AFFINE_ABSENCE_ERROR_MSG = "BaseDataset 'affine' may not be None"
57+
"""BaseDataset initialization affine absence error message."""
58+
59+
AFFINE_OBJECT_ERROR_MSG = "BaseDataset 'affine' must be a numpy array."
60+
"""BaseDataset initialization affine object error message."""
61+
62+
AFFINE_NDIM_ERROR_MSG = "BaseDataset 'affine' must be a 2D array"
63+
"""Affine dimensionality error message."""
64+
65+
AFFINE_SHAPE_ERROR_MSG = "BaseDataset 'affine' must be a 2-D numpy array (4 x 4)"
66+
"""BaseDataset initialization affine shape error message."""
67+
68+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG = "BaseDataset 'brainmask' shape ({brainmask_shape}) does not match dataset volumes ({data_shape})."
69+
"""BaseDataset brainmask shape mismatch error message."""
70+
71+
72+
def _has_dim_size(value: Any, size: int) -> bool:
73+
"""Return ``True`` if ``value`` has a ``.shape`` attribute and one of its
74+
dimensions equals ``size``.
75+
76+
This is useful for checks where at least one axis must match an expected
77+
length. It does not require a specific axis index; it only verifies presence
78+
of the size in any axis in ``.shape``.
79+
80+
Parameters
81+
----------
82+
value : :obj:`Any`
83+
Object to inspect. Typical inputs are NumPy arrays or objects exposing
84+
``.shape``.
85+
size : :obj:`int`
86+
The required dimension size to look for in ``value.shape``.
87+
88+
Returns
89+
-------
90+
:obj:`bool`
91+
``True`` if ``.shape`` exists and any of its integers equals ``size``,
92+
``False`` otherwise.
93+
94+
Examples
95+
--------
96+
>>> _has_dim_size(np.zeros((10, 3)), 3)
97+
True
98+
>>> _has_dim_size(np.zeros((4, 5)), 6)
99+
False
100+
"""
101+
102+
shape = getattr(value, "shape", None)
103+
if shape is None:
104+
return False
105+
# Shape may be an object that is not iterable; handle TypeError explicitly
106+
try:
107+
return size in tuple(shape)
108+
except TypeError:
109+
return False
110+
111+
112+
def _has_ndim(value: Any, ndim: int) -> bool:
113+
"""Check if ``value`` has ``ndim`` dimensionality.
114+
115+
Returns ``True`` if `value` has an ``.ndim`` attribute equal to ``ndim``, or
116+
if it has a ``.shape`` attribute whose length equals ``ndim``.
117+
118+
This helper is tolerant: it accepts objects that either:
119+
- expose an integer ``.ndim`` attribute (e.g., NumPy arrays), or
120+
- expose a ``.shape`` attribute (sequence/tuple-like) whose length equals
121+
``ndim``.
122+
123+
Parameters
124+
----------
125+
value : :obj:`Any`
126+
Object to inspect for dimensionality. Typical inputs are NumPy arrays,
127+
array-likes, or objects that provide ``.ndim`` / ``.shape``.
128+
ndim : :obj:`int`
129+
The required dimensionality.
130+
131+
Returns
132+
-------
133+
:obj:`bool`
134+
``True`` if ``value`` appears to have ``ndim`` dimensions, ``False``
135+
otherwise.
136+
137+
Examples
138+
--------
139+
>>> _has_ndim(np.zeros((2, 3)), 2)
140+
True
141+
>>> _has_ndim(np.zeros((3,)), 2)
142+
False
143+
>>> class WithShape:
144+
... shape = (2, 2, 2)
145+
>>> _has_ndim(WithShape(), 3)
146+
True
147+
"""
148+
149+
# Prefer .ndim if available
150+
ndim_attr = getattr(value, "ndim", None)
151+
if ndim_attr is not None:
152+
try:
153+
return int(ndim_attr) == ndim
154+
except (TypeError, ValueError):
155+
return False
156+
157+
# Fallback to checking shape length
158+
shape = getattr(value, "shape", None)
159+
if shape is None:
160+
return False
161+
try:
162+
return len(tuple(shape)) == ndim
163+
except TypeError:
164+
return False
165+
47166

48167
def _data_repr(value: np.ndarray | None) -> str:
49168
if value is None:
@@ -58,6 +177,76 @@ def _cmp(lh: Any, rh: Any) -> bool:
58177
return lh == rh
59178

60179

180+
def _dataobj_validator(inst, attr, value) -> None:
181+
"""Strict validator for data objects.
182+
183+
It enforces that ``value`` is present and is a NumPy array with exactly 4
184+
dimensions (``ndim == 4``).
185+
186+
This function is intended for use as an attrs-style validator.
187+
188+
Parameters
189+
----------
190+
inst : :obj:`Any`
191+
The instance being validated (unused, present for validator signature).
192+
attr : :obj:`Any`
193+
The attribute being validated (unused, present for validator signature).
194+
value : :obj:`Any`
195+
The value to validate.
196+
197+
Raises
198+
------
199+
exc:`TypeError`
200+
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
201+
exc:`ValueError`
202+
If the value is ``None``, or not 4-dimensional.
203+
"""
204+
if value is None:
205+
raise ValueError(DATAOBJ_ABSENCE_ERROR_MSG)
206+
207+
if not isinstance(value, np.ndarray):
208+
raise TypeError(DATAOBJ_OBJECT_ERROR_MSG)
209+
210+
if not _has_ndim(value, 4):
211+
raise ValueError(DATAOBJ_NDIM_ERROR_MSG)
212+
213+
214+
def _affine_validator(inst, attr, value) -> None:
215+
"""Strict validator for affine matrices.
216+
217+
It enforces that ``value`` is present and is a 4x4 NumPy array.
218+
219+
This function is intended for use as an attrs-style validator.
220+
221+
Parameters
222+
----------
223+
inst : :obj:`Any`
224+
The instance being validated (unused, present for validator signature).
225+
attr : :obj:`Any`
226+
The attribute being validated (unused, present for validator signature).
227+
value : :obj:`Any`
228+
The value to validate.
229+
230+
Raises
231+
------
232+
exc:`TypeError`
233+
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
234+
exc:`ValueError`
235+
If the value is ``None``, or not shaped ``(4, 4)``.
236+
"""
237+
if value is None:
238+
raise ValueError(AFFINE_ABSENCE_ERROR_MSG)
239+
240+
if not isinstance(value, np.ndarray):
241+
raise TypeError(AFFINE_OBJECT_ERROR_MSG)
242+
243+
if not _has_ndim(value, 2):
244+
raise ValueError(AFFINE_NDIM_ERROR_MSG)
245+
246+
if value.shape != (4, 4):
247+
raise ValueError(AFFINE_SHAPE_ERROR_MSG)
248+
249+
61250
@attrs.define(slots=True)
62251
class BaseDataset(Generic[Unpack[Ts]]):
63252
"""
@@ -75,9 +264,13 @@ class BaseDataset(Generic[Unpack[Ts]]):
75264
76265
"""
77266

78-
dataobj: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
267+
dataobj: np.ndarray = attrs.field(
268+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_dataobj_validator
269+
)
79270
"""A :obj:`~numpy.ndarray` object for the data array."""
80-
affine: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
271+
affine: np.ndarray = attrs.field(
272+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_affine_validator
273+
)
81274
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
82275
brainmask: np.ndarray | None = attrs.field(
83276
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
@@ -95,6 +288,22 @@ class BaseDataset(Generic[Unpack[Ts]]):
95288
)
96289
"""A path to an HDF5 file to store the whole dataset."""
97290

291+
def __attrs_post_init__(self) -> None:
292+
"""Enforce basic consistency of base dataset fields at instantiation
293+
time.
294+
295+
Specifically, the brainmask (if present) must match spatial shape of
296+
dataobj.
297+
"""
298+
299+
if self.brainmask is not None:
300+
if self.brainmask.shape != tuple(self.dataobj.shape[:3]):
301+
raise ValueError(
302+
BRAINMASK_SHAPE_MISMATCH_ERROR_MSG.format(
303+
brainmask_shape=self.brainmask.shape, data_shape=self.dataobj.shape[:3]
304+
)
305+
)
306+
98307
def __len__(self) -> int:
99308
"""Obtain the number of volumes/frames in the dataset."""
100309
return self.dataobj.shape[-1]

0 commit comments

Comments
 (0)