From 803b020645c6a42891d49ccb329115d33702db4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Thu, 20 Nov 2025 22:00:36 -0500 Subject: [PATCH] ENH: Validate PET data objects' attributes at instantiation Validate PET data objects' attributes at instantiation: ensure that the attributes are present and match the expected dimensionalities. Refactor the PET attributes so that only the required (`frame_time` and `uptake`) and optional (`frame_duration`) parameters are accepted by the constructor. The `midframe` and the `total_duration` attributes can be computed from the required parameters, so exclude them from `__init__`. Although `uptake` can also be computed from the PET frame data, the rationale behind requiring it is similar to the one for the DWI class `bzero`: users will be able to compute the `uptake` using their preferred strategy and provide it to the constructor. For the `from_nii` function, if a callable is provided, it will be used to compute the value; otherwise a default strategy is used to compute it. Refactor the `from_nii` function so that the required parameters are present when instantiating the PET instance. Increase consistency with the `dmri` data module `from_nii` counterpart function. Refactor the PET data creation fixture in `conftest.py` to accept the required/optional arguments and to return the necessary data. Refactor the tests accordingly and increase consistency with the `dmri` data module testing helper functions. This reduces cognitive load and maintenance burden. Add additional object instantiation equality checks: check that objects instantiated through reading NIfTI files equal objects instantiated directly.
--- src/nifreeze/data/pet.py | 170 +++++++++++++++++++++++------ test/conftest.py | 15 ++- test/test_data_pet.py | 227 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 374 insertions(+), 38 deletions(-) diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py index b3a1f69c6..26d439c16 100644 --- a/src/nifreeze/data/pet.py +++ b/src/nifreeze/data/pet.py @@ -38,20 +38,131 @@ from nitransforms.resampling import apply from typing_extensions import Self -from nifreeze.data.base import BaseDataset, _cmp, _data_repr +from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim from nifreeze.utils.ndimage import load_api +ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None" +"""PET initialization array attribute absence error message.""" + +ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array." +"""PET initialization array attribute object error message.""" + +ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array." +"""PET initialization array attribute ndim error message.""" + +ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\ +PET '{attribute}' length does not match number of frames: \ +expected {n_frames} values, found {attr_len}.""" +"""PET attribute shape mismatch error message.""" + + +def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None: + """Strict validator to ensure an attribute is a 1D NumPy array. + + Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly + one dimension (``value.ndim == 1``). + + This function is intended for use as an attrs-style validator. + + Parameters + ---------- + inst : :obj:`~nifreeze.data.pet.PET` + The instance being validated (unused; present for validator signature). + attr : :obj:`~attrs.Attribute` + The attribute being validated; ``attr.name`` is used in the error message. + value : :obj:`Any` + The value to validate. 
+ + Raises + ------ + :exc:`TypeError` + If the value is not a :obj:`~numpy.ndarray`. + :exc:`ValueError` + If the value is ``None``, or not 1D. + """ + + if value is None: + raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name)) + + if not isinstance(value, np.ndarray): + raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) + + if not _has_ndim(value, 1): + raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name)) + @attrs.define(slots=True) class PET(BaseDataset[np.ndarray]): - """Data representation structure for PET data.""" + """Data representation structure for PET data. + + If not provided, frame duration data are computed as differences between + consecutive frame times. The last interval is duplicated. + """ - midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) + frame_time: np.ndarray = attrs.field( + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array + ) + """A (N,) numpy array specifying the timing of each sample or frame.""" + uptake: np.ndarray = attrs.field( + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array + ) + """A (N,) numpy array specifying the uptake value of each sample or frame.""" + frame_duration: np.ndarray | None = attrs.field( + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp) + ) + """A (N,) numpy array specifying the frame duration.""" + midframe: np.ndarray = attrs.field( + default=None, repr=_data_repr, init=False, eq=attrs.cmp_using(eq=_cmp) + ) """A (N,) numpy array specifying the midpoint timing of each sample or frame.""" - total_duration: float = attrs.field(default=None, repr=True) + total_duration: float = attrs.field(default=None, repr=True, init=False) """A float representing the total duration of the dataset.""" - uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
- """A (N,) numpy array specifying the uptake value of each sample or frame.""" + + def __attrs_post_init__(self) -> None: + """Enforce presence and basic consistency of PET data fields at + instantiation time. + + Specifically, the length of the frame_time and uptake attributes must + match the last dimension of the data (number of frames). + + Computes the values for the derived attributes. + """ + n_frames = int(self.dataobj.shape[-1]) + + if len(self.frame_time) != n_frames: + raise ValueError( + ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format( + attribute=attrs.fields_dict(self.__class__)["frame_time"].name, + n_frames=n_frames, + attr_len=len(self.frame_time), + ) + ) + + if len(self.uptake) != n_frames: + raise ValueError( + ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format( + attribute=attrs.fields_dict(self.__class__)["uptake"].name, + n_frames=n_frames, + attr_len=len(self.uptake), + ) + ) + + # Compute temporal attributes + + # Convert to a float32 numpy array and zero out the earliest time + frame_time_arr = np.array(self.frame_time, dtype=np.float32) + frame_time_arr -= frame_time_arr[0] + self.midframe = frame_time_arr + + # Use frame durations if provided (test against None: array truthiness is ambiguous) + if self.frame_duration is not None: + durations = np.array(self.frame_duration, dtype=np.float32) + else: + durations = _compute_frame_duration(self.midframe) + + # Compute total duration and shift midframe to the midpoint + self.total_duration = float(self.midframe[-1] + durations[-1]) + self.midframe = self.midframe + 0.5 * durations def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]: return (self.midframe[idx],) @@ -223,6 +334,7 @@ def from_nii( brainmask_file: Path | str | None = None, motion_file: Path | str | None = None, frame_duration: np.ndarray | list[float] | None = None, + uptake_stat_func: Callable[..., np.ndarray] = np.sum, ) -> PET: """ Load PET data from NIfTI, creating a PET object with appropriate metadata. 
@@ -242,6 +354,8 @@ def from_nii( The duration of each frame. If ``None``, it is derived by the difference of consecutive frame times, defaulting the last frame to match the second-last. + uptake_stat_func : :obj:`Callable`, optional + The statistic function to be used to compute the uptake value. Returns ------- @@ -258,37 +372,29 @@ def from_nii( raise NotImplementedError filename = Path(filename) - # Load from NIfTI - img = load_api(filename, SpatialImage) - data = img.get_fdata(dtype=np.float32) - pet_obj = PET( - dataobj=data, - affine=img.affine, - ) - - pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum) - # Convert to a float32 numpy array and zero out the earliest time - frame_time_arr = np.array(frame_time, dtype=np.float32) - frame_time_arr -= frame_time_arr[0] - pet_obj.midframe = frame_time_arr - - # If the user doesn't provide frame_duration, we derive it: - if frame_duration is None: - durations = _compute_frame_duration(pet_obj.midframe) - else: - durations = np.array(frame_duration, dtype=np.float32) + # 1) Load a NIfTI + img = load_api(filename, SpatialImage) + fulldata = img.get_fdata(dtype=np.float32) - # Set total_duration and shift frame_time to the midpoint - pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1]) - pet_obj.midframe = frame_time_arr + 0.5 * durations + # 2) Determine uptake value + uptake = _compute_uptake_statistic(fulldata, stat_func=uptake_stat_func) - # If a brain mask is provided, load and attach + # 3) If a brainmask_file was provided, load it + brainmask_data = None if brainmask_file is not None: mask_img = load_api(brainmask_file, SpatialImage) - pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool) + brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool) - return pet_obj + # 4) Create and return the PET instance + return PET( + dataobj=fulldata, + affine=img.affine, + brainmask=brainmask_data, + frame_time=np.asarray(frame_time), + frame_duration=None if frame_duration is None else np.asarray(frame_duration), +
uptake=uptake, + ) def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray: @@ -313,7 +419,7 @@ def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray: return durations -def _compute_uptake_statistic(data: np.ndarray, stat_func: Callable = np.sum): +def _compute_uptake_statistic(data: np.ndarray, stat_func: Callable[..., np.ndarray] = np.sum): """Compute a statistic over all voxels for each frame on a PET sequence. Assumes the last dimension corresponds to the number of frames in the diff --git a/test/conftest.py b/test/conftest.py index 4b9a091ce..b97d041d0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -323,22 +323,27 @@ def setup_random_pet_data(request): n_frames = 5 vol_size = (4, 4, 4) - midframe = np.arange(n_frames, dtype=np.float32) + 1 - total_duration = float(n_frames + 1) + uptake_stat_func = np.sum + frame_duration = None if marker: - n_frames, vol_size, midframe, total_duration = marker.args + n_frames, vol_size, uptake_stat_func, frame_duration = marker.args rng = request.node.rng + frame_time = np.arange(n_frames, dtype=np.float32) + 1 + pet_dataobj, affine = _generate_random_uniform_spatial_data( request, (*vol_size, n_frames), 0.0, 1.0 ) brainmask_dataobj = rng.choice([True, False], size=vol_size).astype(bool) + uptake = uptake_stat_func(pet_dataobj.reshape(-1, pet_dataobj.shape[-1]), axis=0) + return ( pet_dataobj, affine, brainmask_dataobj, - midframe, - total_duration, + frame_time, + uptake, + frame_duration, ) diff --git a/test/test_data_pet.py b/test/test_data_pet.py index f03861a29..462e2911c 100644 --- a/test/test_data_pet.py +++ b/test/test_data_pet.py @@ -29,10 +29,36 @@ import pytest from nitransforms.linear import Affine -from nifreeze.data.pet import PET, _compute_frame_duration, _compute_uptake_statistic, from_nii +from nifreeze.data.pet import ( + ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG, + ARRAY_ATTRIBUTE_NDIM_ERROR_MSG, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR, + PET, 
+ _compute_frame_duration, + _compute_uptake_statistic, + from_nii, +) from nifreeze.utils.ndimage import load_api +def _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj): + pet = nb.Nifti1Image(pet_dataobj, affine) + brainmask = nb.Nifti1Image(brainmask_dataobj, affine) + + return pet, brainmask + + +def _serialize_pet_data(pet, brainmask, _tmp_path): + pet_fname = _tmp_path / "pet.nii.gz" + brainmask_fname = _tmp_path / "brainmask.nii.gz" + + nb.save(pet, pet_fname) + nb.save(brainmask, brainmask_fname) + + return pet_fname, brainmask_fname + + @pytest.fixture def random_dataset(setup_random_pet_data) -> PET: """Create a PET dataset with random data for testing.""" @@ -63,6 +89,127 @@ def random_nifti_file(tmp_path, setup_random_uniform_spatial_data) -> Path: return _filename +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize( + "attribute_name, value, expected_exc, expected_msg", + [ + ("frame_time", None, ValueError, ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG), + ("frame_time", 1, TypeError, ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG), + ("uptake", None, ValueError, ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG), + ("uptake", 3.0, TypeError, ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG), + ], +) +def test_pet_instantiation_attribute_basic_errors( + setup_random_uniform_spatial_data, attribute_name, value, expected_exc, expected_msg +): + data, affine = setup_random_uniform_spatial_data + + if attribute_name == "frame_time": + frame_time = value + uptake = np.zeros(data.shape[-1], dtype=np.float32) + + with pytest.raises(expected_exc, match=expected_msg.format(attribute="frame_time")): + PET(dataobj=data, affine=affine, frame_time=frame_time, uptake=uptake) # type: ignore[arg-type] + + elif attribute_name == "uptake": + uptake = value + frame_time = np.zeros(data.shape[-1], dtype=np.float32) + + with pytest.raises(expected_exc, match=expected_msg.format(attribute="uptake")): + PET(dataobj=data, affine=affine, frame_time=frame_time, uptake=uptake) # type: 
ignore[arg-type] + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.sum, None) +@pytest.mark.parametrize("attribute_name", ("frame_time", "uptake")) +@pytest.mark.parametrize("additional_dimensions", (1, 2)) +@pytest.mark.parametrize("transpose", (True, False)) +def test_pet_instantiation_attribute_ndim_errors( + request, setup_random_pet_data, attribute_name, additional_dimensions, transpose +): + rng = request.node.rng + ( + pet_dataobj, + affine, + _, + frame_time, + uptake, + _, + ) = setup_random_pet_data + + if attribute_name == "frame_time": + frame_time = np.concatenate( + [frame_time[:, None], rng.random((frame_time.size, additional_dimensions))], axis=1 + ) + frame_time = frame_time.T if transpose else frame_time + with pytest.raises( + ValueError, match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="frame_time") + ): + PET(dataobj=pet_dataobj, affine=affine, frame_time=frame_time, uptake=uptake) # type: ignore[arg-type] + + elif attribute_name == "uptake": + uptake = np.concatenate( + [uptake[:, None], rng.random((uptake.size, additional_dimensions))], axis=1 + ) + uptake = uptake.T if transpose else uptake + with pytest.raises( + ValueError, match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="uptake") + ): + PET(dataobj=pet_dataobj, affine=affine, frame_time=frame_time, uptake=uptake) # type: ignore[arg-type] + + +@pytest.mark.random_pet_data(4, (2, 2, 2), np.sum, None) +@pytest.mark.parametrize("attribute_name", ("frame_time", "uptake")) +@pytest.mark.parametrize( + ("additional_volume_count", "additional_attribute_count"), + [(1, 0), (2, 0), (2, 1), (0, 1), (0, 2), (1, 2)], +) +def test_pet_instantiation_attribute_vol_mismatch_error( + setup_random_pet_data, attribute_name, additional_volume_count, additional_attribute_count +): + ( + pet_dataobj, + affine, + _, + frame_time, + uptake, + _, + ) = setup_random_pet_data + + n_frames = int(pet_dataobj.shape[-1]) + + # Add additional volumes: simply concatenate the last volume + if 
additional_volume_count: + additional_dwi_dataobj = np.tile(pet_dataobj[..., -1:], (1, additional_volume_count)) + pet_dataobj = np.concatenate((pet_dataobj, additional_dwi_dataobj), axis=-1) + n_frames = int(pet_dataobj.shape[-1]) + # Add additional values to attribute: simply concatenate the attribute + if additional_attribute_count: + if attribute_name == "frame_time": + additional_frame_time = np.repeat(frame_time[-1], additional_attribute_count) + frame_time = np.concatenate((frame_time, additional_frame_time)) + elif attribute_name == "uptake": + additional_uptake = np.repeat(uptake[-1], additional_attribute_count) + uptake = np.concatenate((uptake, additional_uptake)) + + if attribute_name == "frame_time" or additional_volume_count != 0: + with pytest.raises( + ValueError, + match=ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format( + attribute="frame_time", n_frames=n_frames, attr_len=len(frame_time) + ), + ): + PET(dataobj=pet_dataobj, affine=affine, frame_time=frame_time, uptake=uptake) # type: ignore[arg-type] + + elif attribute_name == "uptake": + with pytest.raises( + ValueError, + match=ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR.format( + attribute="uptake", n_frames=n_frames, attr_len=len(uptake) + ), + ): + PET(dataobj=pet_dataobj, affine=affine, frame_time=frame_time, uptake=uptake) # type: ignore[arg-type] + + @pytest.mark.parametrize( "midframe, expected", [ @@ -175,6 +322,84 @@ def test_round_trip(tmp_path, random_nifti_file, frame_time, frame_duration): assert units[0] == "mm" +def test_equality_operator(tmp_path, setup_random_pet_data): + ( + pet_dataobj, + affine, + brainmask_dataobj, + midframe, + total_duration, + ) = setup_random_pet_data + + pet, brainmask = _pet_data_to_nifti( + pet_dataobj, + affine, + brainmask_dataobj.astype(np.uint8), + ) + + ( + pet_fname, + brainmask_fname, + ) = _serialize_pet_data( + pet, + brainmask, + tmp_path, + ) + + # ToDo + # All this needs to be done automatically: the from_nii and PET data + # 
instantiation need to be refactored + + start_time = 0.0 + mid = np.asarray(midframe, dtype=np.float32) + frame_time = (mid + np.float32(start_time)).astype(np.float32) + frame_duration = _compute_frame_duration(midframe) + + # Read back using public API + pet_obj_from_nii = from_nii( + pet_fname, + frame_time=frame_time, + brainmask_file=brainmask_fname, + frame_duration=frame_duration, + ) + + # Direct instantiation with the same arrays + uptake = _compute_uptake_statistic(pet_dataobj, stat_func=np.sum) + pet_obj_direct = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + uptake=uptake, + ) + + # Sanity checks (element-wise) + assert np.allclose(pet_obj_direct.dataobj, pet_obj_from_nii.dataobj) + assert np.allclose(pet_obj_direct.affine, pet_obj_from_nii.affine) + if pet_obj_direct.brainmask is None or pet_obj_from_nii.brainmask is None: + assert pet_obj_direct.brainmask is None + assert pet_obj_from_nii.brainmask is None + else: + assert np.array_equal(pet_obj_direct.brainmask, pet_obj_from_nii.brainmask) + assert np.allclose(pet_obj_direct.midframe, pet_obj_from_nii.midframe) + assert np.allclose(pet_obj_direct.total_duration, pet_obj_from_nii.total_duration) + assert np.allclose(pet_obj_direct.uptake, pet_obj_from_nii.uptake) + + # Test equality operator + assert pet_obj_direct == pet_obj_from_nii + + # Test equality operator against an instance from HDF5 + hdf5_filename = tmp_path / "test_pet.h5" + pet_obj_from_nii.to_filename(hdf5_filename) + + round_trip_pet_obj = PET.from_filename(hdf5_filename) + + # Symmetric equality + assert pet_obj_from_nii == round_trip_pet_obj + assert round_trip_pet_obj == pet_obj_from_nii + + @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) def test_pet_set_transform_updates_motion_affines(random_dataset): idx = 2