Skip to content

Commit 7dc1033

Browse files
committed
ENH: Validate PET data objects' attributes at instantiation
Validate PET data objects' attributes at instantiation: ensures that the attributes are present and match the expected dimensionalities. Refactor the PET attributes so that only the required (`frame_time` and `uptake`) and optional (`frame_duration`) parameters are accepted by the constructor. The `midframe` and the `total_duration` attributes can be computed from the required parameters, so exclude it from `__init__`. Although `uptake` can also be computed from the PET frame data, the rationale behind requiring it is similar to the one for the DWI class `bzero`: users will be able compute the `uptake` using their preferred strategy and provide it to the constructor. For the `from_nii` function, if a callable is provided, it will be used to compute the value; otherwise a default strategy is used to compute it. Refactor the `from_nii` function so that the required parameters are present when instantiating the PET instance. Increase consistency with the `dmri` data module `from_nii` counterpart function. Refactor the PET data creation fixture in `conftest.py` to accept the required/optional arguments and to return the necessary data. Refactor the tests accordingly and increase consistency with the `dmri` data module testing helper functions. Reduces cognitive load and maintenance burden. Add additional object instantiation equality checks: check that objects intantiated through reading NIfTI files equal objects instantiated directly.
1 parent 3bca8be commit 7dc1033

File tree

3 files changed

+373
-37
lines changed

3 files changed

+373
-37
lines changed

src/nifreeze/data/pet.py

Lines changed: 137 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,20 +38,131 @@
3838
from nitransforms.resampling import apply
3939
from typing_extensions import Self
4040

41-
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
41+
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim
4242
from nifreeze.utils.ndimage import load_api
4343

44+
ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+
"""PET initialization array attribute absence error message."""
46+
47+
ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
48+
"""PET initialization array attribute object error message."""
49+
50+
ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
51+
"""PET initialization array attribute ndim error message."""
52+
53+
ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
54+
PET '{attribute}' length does not match number of frames: \
55+
expected {n_frames} values, found {attr_len}."""
56+
"""PET attribute shape mismatch error message."""
57+
58+
59+
def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None:
60+
"""Strict validator to ensure an attribute is a 1D NumPy array.
61+
62+
Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly
63+
one dimension (``value.ndim == 1``).
64+
65+
This function is intended for use as an attrs-style validator.
66+
67+
Parameters
68+
----------
69+
inst : :obj:`~nifreeze.data.pet.PET`
70+
The instance being validated (unused; present for validator signature).
71+
attr : :obj:`~attrs.Attribute`
72+
The attribute being validated; ``attr.name`` is used in the error message.
73+
value : :obj:`Any`
74+
The value to validate.
75+
76+
Raises
77+
------
78+
exc:`TypeError`
79+
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
80+
exc:`ValueError`
81+
If the value is ``None``, or not 1D.
82+
"""
83+
84+
if value is None:
85+
raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))
86+
87+
if not isinstance(value, np.ndarray):
88+
raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))
89+
90+
if not _has_ndim(value, 1):
91+
raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name))
92+
4493

4594
@attrs.define(slots=True)
4695
class PET(BaseDataset[np.ndarray]):
47-
"""Data representation structure for PET data."""
96+
"""Data representation structure for PET data.
97+
98+
If not provided, frame duration data are computed as differences between
99+
consecutive midframe times. The last interval is duplicated.
100+
"""
48101

49-
midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
102+
frame_time: np.ndarray = attrs.field(
103+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array
104+
)
105+
"""A (N,) numpy array specifying the timing of each sample or frame."""
106+
uptake: np.ndarray = attrs.field(
107+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array
108+
)
109+
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
110+
frame_duration: np.ndarray | None = attrs.field(
111+
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)
112+
)
113+
"""A (N,) numpy array specifying the frame duration."""
114+
midframe: np.ndarray = attrs.field(
115+
default=None, repr=_data_repr, init=False, eq=attrs.cmp_using(eq=_cmp)
116+
)
50117
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51-
total_duration: float = attrs.field(default=None, repr=True)
118+
total_duration: float = attrs.field(default=None, repr=True, init=False)
52119
"""A float representing the total duration of the dataset."""
53-
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
54-
"""A (N,) numpy array specifying the uptake value of each sample or frame."""
120+
121+
def __attrs_post_init__(self) -> None:
122+
"""Enforce presence and basic consistency of PET data fields at
123+
instantiation time.
124+
125+
Specifically, the length of the frame_time and uptake attributes must
126+
match the last dimension of the data (number of frames).
127+
128+
Computes the values for the private attributes.
129+
"""
130+
n_frames = int(self.dataobj.shape[-1])
131+
132+
if len(self.frame_time) != n_frames:
133+
raise ValueError(
134+
ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format(
135+
attribute=attrs.fields_dict(self.__class__)["frame_time"].name,
136+
n_frames=n_frames,
137+
attr_len=len(self.frame_time),
138+
)
139+
)
140+
141+
if len(self.uptake) != n_frames:
142+
raise ValueError(
143+
ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format(
144+
attribute=attrs.fields_dict(self.__class__)["uptake"].name,
145+
n_frames=n_frames,
146+
attr_len=len(self.uptake),
147+
)
148+
)
149+
150+
# Compute temporal attributes
151+
152+
# Convert to a float32 numpy array and zero out the earliest time
153+
frame_time_arr = np.array(self.frame_time, dtype=np.float32)
154+
frame_time_arr -= frame_time_arr[0]
155+
self.midframe = frame_time_arr
156+
157+
# If the user did not provide frame duration values,compute them
158+
if self.frame_duration:
159+
durations = np.array(self.frame_duration, dtype=np.float32)
160+
else:
161+
durations = _compute_frame_duration(self.midframe)
162+
163+
# Compute total duration and shift midframe to the midpoint
164+
self.total_duration = float(self.midframe[-1] + durations[-1])
165+
self.midframe = self.midframe + 0.5 * durations
55166

56167
def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
57168
return (self.midframe[idx],)
@@ -223,6 +334,7 @@ def from_nii(
223334
brainmask_file: Path | str | None = None,
224335
motion_file: Path | str | None = None,
225336
frame_duration: np.ndarray | list[float] | None = None,
337+
uptake_stat_func: Callable[..., np.ndarray] | None = np.sum,
226338
) -> PET:
227339
"""
228340
Load PET data from NIfTI, creating a PET object with appropriate metadata.
@@ -242,6 +354,8 @@ def from_nii(
242354
The duration of each frame.
243355
If ``None``, it is derived by the difference of consecutive frame times,
244356
defaulting the last frame to match the second-last.
357+
uptake_stat_func : :obj:`Callable`, optional
358+
The statistic function to be used to compute the uptake value.
245359
246360
Returns
247361
-------
@@ -258,37 +372,29 @@ def from_nii(
258372
raise NotImplementedError
259373

260374
filename = Path(filename)
261-
# Load from NIfTI
262-
img = load_api(filename, SpatialImage)
263-
data = img.get_fdata(dtype=np.float32)
264-
pet_obj = PET(
265-
dataobj=data,
266-
affine=img.affine,
267-
)
268-
269-
pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum)
270375

271-
# Convert to a float32 numpy array and zero out the earliest time
272-
frame_time_arr = np.array(frame_time, dtype=np.float32)
273-
frame_time_arr -= frame_time_arr[0]
274-
pet_obj.midframe = frame_time_arr
275-
276-
# If the user doesn't provide frame_duration, we derive it:
277-
if frame_duration is None:
278-
durations = _compute_frame_duration(pet_obj.midframe)
279-
else:
280-
durations = np.array(frame_duration, dtype=np.float32)
376+
# 1) Load a NIfTI
377+
img = load_api(filename, SpatialImage)
378+
fulldata = img.get_fdata(dtype=np.float32)
281379

282-
# Set total_duration and shift frame_time to the midpoint
283-
pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1])
284-
pet_obj.midframe = frame_time_arr + 0.5 * durations
380+
# 2) Determine uptake value
381+
uptake = _compute_uptake_statistic(fulldata, stat_func=uptake_stat_func)
285382

286-
# If a brain mask is provided, load and attach
383+
# 3) If a brainmask_file was provided, load it
384+
brainmask_data = None
287385
if brainmask_file is not None:
288386
mask_img = load_api(brainmask_file, SpatialImage)
289-
pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
387+
brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)
290388

291-
return pet_obj
389+
# 4) Create and return the PET instance
390+
return PET(
391+
dataobj=fulldata,
392+
affine=img.affine,
393+
brainmask=brainmask_data,
394+
frame_time=frame_time,
395+
frame_duration=frame_duration,
396+
uptake=uptake,
397+
)
292398

293399

294400
def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:

test/conftest.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -323,22 +323,27 @@ def setup_random_pet_data(request):
323323

324324
n_frames = 5
325325
vol_size = (4, 4, 4)
326-
midframe = np.arange(n_frames, dtype=np.float32) + 1
327-
total_duration = float(n_frames + 1)
326+
uptake_stat_func = np.sum
327+
frame_duration = None
328328
if marker:
329-
n_frames, vol_size, midframe, total_duration = marker.args
329+
n_frames, vol_size, uptake_stat_func, frame_duration = marker.args
330330

331331
rng = request.node.rng
332332

333+
frame_time = np.arange(n_frames, dtype=np.float32) + 1
334+
333335
pet_dataobj, affine = _generate_random_uniform_spatial_data(
334336
request, (*vol_size, n_frames), 0.0, 1.0
335337
)
336338
brainmask_dataobj = rng.choice([True, False], size=vol_size).astype(bool)
337339

340+
uptake = uptake_stat_func(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0)
341+
338342
return (
339343
pet_dataobj,
340344
affine,
341345
brainmask_dataobj,
342-
midframe,
343-
total_duration,
346+
frame_time,
347+
uptake,
348+
frame_duration,
344349
)

0 commit comments

Comments
 (0)