Skip to content

Commit e094509

Browse files
committed
ENH: Validate PET data objects' attributes at instantiation
Validate PET data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities.
1 parent d5fc6e1 commit e094509

File tree

2 files changed

+234
-5
lines changed

2 files changed

+234
-5
lines changed

src/nifreeze/data/pet.py

Lines changed: 114 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,131 @@
3838
from nitransforms.resampling import apply
3939
from typing_extensions import Self
4040

41-
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
41+
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim
4242
from nifreeze.utils.ndimage import load_api
4343

44+
ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+
"""PET initialization array attribute absence error message."""
46+
47+
ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
48+
"""PET initialization array attribute object error message."""
49+
50+
ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1-D numpy array."
51+
"""PET initialization array attribute ndim error message."""
52+
53+
SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a scalar."
54+
"""PET initialization scalar attribute shape error message."""
55+
56+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = (
57+
"PET '{attribute}' length ({attr_len}) does not match number of frames ({data_frames})"
58+
)
59+
"""PET attribute shape mismatch error message."""
60+
61+
62+
def _1d_array_validator(inst: PET, attr: attrs.Attribute, value: Any) -> None:
63+
"""Strict validator to ensure an attribute is a 1-D NumPy array.
64+
65+
This validator enforces that ``value`` is a :obj:`~numpy.ndarray` and that
66+
it has exactly one dimension (``value.ndim == 1``).
67+
68+
This function is intended for use as an attrs-style validator.
69+
70+
Parameters
71+
----------
72+
inst : :obj:`~nifreeze.data.base.PET`
73+
The instance being validated (unused; present for validator signature).
74+
attr : :obj:`attrs.Attribute`
75+
The attribute being validated; ``attr.name`` is used in the error message.
76+
value : :obj:`Any`
77+
The value to validate.
78+
79+
Raises
80+
------
81+
exc:`TypeError`
82+
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
83+
exc:`ValueError`
84+
If the value is ``None``, or not 1-dimensional.
85+
"""
86+
87+
if value is None:
88+
raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))
89+
90+
if not isinstance(value, np.ndarray):
91+
raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))
92+
93+
if not _has_ndim(value, 1):
94+
raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name))
95+
96+
97+
def _scalar_validator(inst: PET, attr: attrs.Attribute, value: Any) -> None:
98+
"""Strict validator to ensure an attribute is a scalar number.
99+
100+
This validator ensures that ``value`` is a Python integer or floating point
101+
number, or a NumPy scalar numeric type (e.g., :obj:`numpy.integer`,
102+
:obj:`numpy.floating`).
103+
104+
This function is intended for use as an attrs-style validator.
105+
106+
Parameters
107+
----------
108+
inst : :obj:`~nifreeze.data.base.PET`
109+
The instance being validated (unused; present for validator signature).
110+
attr : :obj:`attrs.Attribute`
111+
The attribute being validated; attr.name is used in the error message.
112+
value : :obj:`Any`
113+
The value to validate.
114+
115+
Raises
116+
------
117+
exc:`ValueError`
118+
If ``value`` is not an int/float or a NumPy numeric scalar type.
119+
"""
120+
if not isinstance(value, (int, float, np.integer, np.floating)):
121+
raise ValueError(SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))
122+
44123

45124
@attrs.define(slots=True)
46125
class PET(BaseDataset[np.ndarray]):
47126
"""Data representation structure for PET data."""
48127

49-
midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
128+
midframe: np.ndarray = attrs.field(
129+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_1d_array_validator
130+
)
50131
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51-
total_duration: float = attrs.field(default=None, repr=True)
132+
total_duration: float = attrs.field(default=None, repr=True, validator=_scalar_validator)
52133
"""A float representing the total duration of the dataset."""
53-
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
134+
uptake: np.ndarray = attrs.field(
135+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=_1d_array_validator
136+
)
54137
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
55138

139+
def __attrs_post_init__(self) -> None:
140+
"""Enforce presence and basic consistency of required PET fields at
141+
instantiation time.
142+
143+
Specifically, the length of the midframe and uptake attributes must
144+
match the last dimension of the data (number of frames).
145+
"""
146+
data_frames = int(self.dataobj.shape[-1])
147+
148+
if len(self.midframe) != data_frames:
149+
raise ValueError(
150+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
151+
attribute=attrs.fields_dict(self.__class__)["midframe"].name,
152+
attr_len=len(self.midframe),
153+
data_frames=data_frames,
154+
)
155+
)
156+
157+
if len(self.uptake) != data_frames:
158+
raise ValueError(
159+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
160+
attribute=attrs.fields_dict(self.__class__)["uptake"].name,
161+
attr_len=len(self.uptake),
162+
data_frames=data_frames,
163+
)
164+
)
165+
56166
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
57167
return (self.midframe[idx],)
58168

test/test_data_pet.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,25 @@
2222
#
2323

2424
import json
25+
import re
2526
from pathlib import Path
2627

2728
import nibabel as nb
2829
import numpy as np
2930
import pytest
3031
from nitransforms.linear import Affine
3132

32-
from nifreeze.data.pet import PET, _compute_frame_duration, _compute_uptake_statistic, from_nii
33+
from nifreeze.data.pet import (
34+
ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG,
35+
ARRAY_ATTRIBUTE_NDIM_ERROR_MSG,
36+
ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG,
37+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG,
38+
PET,
39+
SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG,
40+
_compute_frame_duration,
41+
_compute_uptake_statistic,
42+
from_nii,
43+
)
3344
from nifreeze.utils.ndimage import load_api
3445

3546

@@ -63,6 +74,114 @@ def random_nifti_file(tmp_path, setup_random_uniform_spatial_data) -> Path:
6374
return _filename
6475

6576

77+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
78+
@pytest.mark.parametrize(
79+
"midframe, expected_exc, expected_msg",
80+
[
81+
(None, ValueError, ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG),
82+
(1, TypeError, ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG),
83+
],
84+
)
85+
def test_pet_attribute_basic_errors(
86+
setup_random_uniform_spatial_data, midframe, expected_exc, expected_msg
87+
):
88+
data, affine = setup_random_uniform_spatial_data
89+
with pytest.raises(expected_exc, match=expected_msg.format(attribute="midframe")):
90+
PET(dataobj=data, affine=affine, midframe=midframe) # type: ignore[arg-type]
91+
92+
93+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
94+
@pytest.mark.parametrize("size", ((2, 1), (1, 2), (3, 1), (3, 2)))
95+
def test_pet_midframe_shape_error(setup_random_uniform_spatial_data, size):
96+
data, affine = setup_random_uniform_spatial_data
97+
midframe = np.zeros(size, dtype=np.float32)
98+
with pytest.raises(
99+
ValueError, match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="midframe")
100+
):
101+
PET(dataobj=data, affine=affine, midframe=midframe)
102+
103+
104+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
105+
@pytest.mark.parametrize("size", ((2, 1), (1, 1)))
106+
def test_pet_total_duration_error(request, setup_random_uniform_spatial_data, size):
107+
data, affine = setup_random_uniform_spatial_data
108+
midframe = np.zeros(data.shape[-1], dtype=np.float32)
109+
total_duration = request.node.rng.uniform(5.0, 20.0, size=size)
110+
with pytest.raises(
111+
ValueError, match=SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute="total_duration")
112+
):
113+
PET(dataobj=data, affine=affine, midframe=midframe, total_duration=total_duration)
114+
115+
116+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
117+
@pytest.mark.parametrize("size", ((2, 1), (1, 2), (3, 1), (3, 2)))
118+
def test_pet_uptake_shape_error(setup_random_uniform_spatial_data, size):
119+
data, affine = setup_random_uniform_spatial_data
120+
midframe = np.zeros(data.shape[-1], dtype=np.float32)
121+
total_duration = 16.2
122+
uptake = np.zeros(size, dtype=np.float32)
123+
with pytest.raises(
124+
ValueError, match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="uptake")
125+
):
126+
PET(
127+
dataobj=data,
128+
affine=affine,
129+
midframe=midframe,
130+
total_duration=total_duration,
131+
uptake=uptake,
132+
)
133+
134+
135+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
136+
def test_pet_midframe_length_mismatch(setup_random_uniform_spatial_data):
137+
data, affine = setup_random_uniform_spatial_data
138+
total_duration = 16.2
139+
data_frames = data.shape[-1]
140+
attr_len = data_frames + 1
141+
midframe = np.zeros(attr_len, dtype=np.float32)
142+
uptake = np.zeros(data_frames, dtype=np.float32)
143+
with pytest.raises(
144+
ValueError,
145+
match=re.escape(
146+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
147+
attribute="midframe", attr_len=attr_len, data_frames=data_frames
148+
)
149+
),
150+
):
151+
PET(
152+
dataobj=data,
153+
affine=affine,
154+
midframe=midframe,
155+
total_duration=total_duration,
156+
uptake=uptake,
157+
)
158+
159+
160+
@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0)
161+
def test_pet_uptake_length_mismatch(setup_random_uniform_spatial_data):
162+
data, affine = setup_random_uniform_spatial_data
163+
total_duration = 16.2
164+
data_frames = data.shape[-1]
165+
midframe = np.zeros(data_frames, dtype=np.float32)
166+
attr_len = data_frames + 1
167+
uptake = np.zeros(attr_len, dtype=np.float32)
168+
with pytest.raises(
169+
ValueError,
170+
match=re.escape(
171+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
172+
attribute="uptake", attr_len=attr_len, data_frames=data_frames
173+
)
174+
),
175+
):
176+
PET(
177+
dataobj=data,
178+
affine=affine,
179+
midframe=midframe,
180+
total_duration=total_duration,
181+
uptake=uptake,
182+
)
183+
184+
66185
@pytest.mark.parametrize(
67186
"midframe, expected",
68187
[

0 commit comments

Comments
 (0)