Skip to content

Commit 96e916b

Browse files
committed
ENH: Validate PET data objects' attributes at instantiation
Validate PET data objects' attributes at instantiation: ensure that the attributes are present and match the expected dimensionalities. Refactor the `from_nii` function so that the required parameters are present when instantiating the PET instance. Increase consistency with the `dmri` data module `from_nii` counterpart function. Refactor the tests accordingly and increase consistency with the `dmri` data module testing helper functions. Reduce cognitive load and maintenance burden. Add additional object instantiation equality checks: check that objects instantiated through reading NIfTI files equal objects instantiated directly.
1 parent 3bca8be commit 96e916b

File tree

2 files changed

+352
-20
lines changed

2 files changed

+352
-20
lines changed

src/nifreeze/data/pet.py

Lines changed: 137 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,130 @@
3838
from nitransforms.resampling import apply
3939
from typing_extensions import Self
4040

41-
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
41+
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim
4242
from nifreeze.utils.ndimage import load_api
4343

44+
ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+
"""PET initialization array attribute absence error message."""
46+
47+
ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
48+
"""PET initialization array attribute object error message."""
49+
50+
ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
51+
"""PET initialization array attribute ndim error message."""
52+
53+
SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a scalar."
54+
"""PET initialization scalar attribute object error message."""
55+
56+
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = (
57+
"PET '{attribute}' length ({attr_len}) does not match number of frames ({data_frames})"
58+
)
59+
"""PET attribute shape mismatch error message."""
60+
61+
62+
def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None:
    """Reject any attribute value that is not a one-dimensional NumPy array.

    The checks run in a fixed order: absence (``None``) first, then the
    object type, then the dimensionality as reported by ``_has_ndim(value, 1)``.

    This function is intended for use as an attrs-style validator.

    Parameters
    ----------
    inst : :obj:`~nifreeze.data.pet.PET`
        The instance being validated (unused; present for validator signature).
    attr : :obj:`~attrs.Attribute`
        The attribute being validated; ``attr.name`` is used in the error message.
    value : :obj:`Any`
        The value to validate.

    Raises
    ------
    :exc:`TypeError`
        If ``value`` is not ``None`` and is not a :obj:`~numpy.ndarray`.
    :exc:`ValueError`
        If ``value`` is ``None``, or is an array that is not 1D.

    """
    # A missing value gets its own message, distinct from a wrongly typed one.
    if value is None:
        raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))

    if isinstance(value, np.ndarray):
        if _has_ndim(value, 1):
            return
        raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name))

    # Anything else (lists, tuples, scalars, ...) is rejected without conversion.
    raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))
95+
96+
97+
def validate_scalar(inst: PET, attr: attrs.Attribute, value: Any) -> None:
    """Strict validator to ensure an attribute is a scalar number.

    Ensures that ``value`` is a Python integer or floating point number, or a
    NumPy scalar numeric type (e.g., :obj:`numpy.integer`, :obj:`numpy.floating`).
    Booleans are rejected: :obj:`bool` is a subclass of :obj:`int` and would
    otherwise be accepted as a number.

    This function is intended for use as an attrs-style validator.

    Parameters
    ----------
    inst : :obj:`~nifreeze.data.pet.PET`
        The instance being validated (unused; present for validator signature).
    attr : :obj:`~attrs.Attribute`
        The attribute being validated; ``attr.name`` is used in the error message.
    value : :obj:`Any`
        The value to validate.

    Raises
    ------
    :exc:`ValueError`
        If ``value`` is a :obj:`bool`, or is not an int/float or a NumPy
        numeric scalar type.

    """
    is_number = isinstance(value, (int, float, np.integer, np.floating))

    # ``isinstance(True, int)`` holds, so bool has to be excluded explicitly.
    if isinstance(value, bool) or not is_number:
        raise ValueError(SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))
121+
44122

45123
@attrs.define(slots=True)
46124
class PET(BaseDataset[np.ndarray]):
47125
"""Data representation structure for PET data."""
48126

49-
midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
127+
midframe: np.ndarray = attrs.field(
128+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array
129+
)
50130
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51-
total_duration: float = attrs.field(default=None, repr=True)
131+
total_duration: float = attrs.field(default=None, repr=True, validator=validate_scalar)
52132
"""A float representing the total duration of the dataset."""
53-
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
133+
uptake: np.ndarray = attrs.field(
134+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array
135+
)
54136
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
55137

138+
def __attrs_post_init__(self) -> None:
    """Check that the per-frame attributes agree with the data at instantiation.

    The number of frames is read from the last dimension of ``dataobj``;
    ``midframe`` and ``uptake`` (checked in that order) must each have that
    many entries.

    Raises
    ------
    :exc:`ValueError`
        If the length of ``midframe`` or ``uptake`` differs from the number
        of frames.

    """
    n_frames = int(self.dataobj.shape[-1])

    for field_key in ("midframe", "uptake"):
        n_values = len(getattr(self, field_key))
        if n_values == n_frames:
            continue

        # The reported name comes from the attrs field definition itself.
        raise ValueError(
            ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
                attribute=attrs.fields_dict(self.__class__)[field_key].name,
                attr_len=n_values,
                data_frames=n_frames,
            )
        )
164+
56165
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
    """Return the ``midframe`` entries selected by ``idx`` as a one-element tuple."""
    selected_midframe = self.midframe[idx]
    return (selected_midframe,)
58167

@@ -258,37 +367,46 @@ def from_nii(
258367
raise NotImplementedError
259368

260369
filename = Path(filename)
261-
# Load from NIfTI
370+
371+
# 1) Load a NIfTI
262372
img = load_api(filename, SpatialImage)
263-
data = img.get_fdata(dtype=np.float32)
264-
pet_obj = PET(
265-
dataobj=data,
266-
affine=img.affine,
267-
)
373+
fulldata = img.get_fdata(dtype=np.float32)
374+
375+
# 2) Determine uptake value
376+
uptake = _compute_uptake_statistic(fulldata, stat_func=np.sum)
268377

269-
pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum)
378+
# 3) Compute temporal features
270379

271380
# Convert to a float32 numpy array and zero out the earliest time
272381
frame_time_arr = np.array(frame_time, dtype=np.float32)
273382
frame_time_arr -= frame_time_arr[0]
274-
pet_obj.midframe = frame_time_arr
383+
midframe = frame_time_arr
275384

276385
# If the user doesn't provide frame_duration, we derive it:
277386
if frame_duration is None:
278-
durations = _compute_frame_duration(pet_obj.midframe)
387+
durations = _compute_frame_duration(midframe)
279388
else:
280389
durations = np.array(frame_duration, dtype=np.float32)
281390

282-
# Set total_duration and shift frame_time to the midpoint
283-
pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1])
284-
pet_obj.midframe = frame_time_arr + 0.5 * durations
391+
# Compute total_duration and shift midframe to the midpoint
392+
total_duration = float(frame_time_arr[-1] + durations[-1])
393+
midframe = frame_time_arr + 0.5 * durations
285394

286-
# If a brain mask is provided, load and attach
395+
# 4) If a brainmask_file was provided, load it
396+
brainmask_data = None
287397
if brainmask_file is not None:
288398
mask_img = load_api(brainmask_file, SpatialImage)
289-
pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
399+
brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)
290400

291-
return pet_obj
401+
# 5) Create and return the PET instance.
402+
return PET(
403+
dataobj=fulldata,
404+
affine=img.affine,
405+
brainmask=brainmask_data,
406+
midframe=midframe,
407+
total_duration=total_duration,
408+
uptake=uptake,
409+
)
292410

293411

294412
def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)