diff --git a/src/nifreeze/data/pet.py b/src/nifreeze/data/pet.py index b3a1f69c6..d40495eab 100644 --- a/src/nifreeze/data/pet.py +++ b/src/nifreeze/data/pet.py @@ -38,21 +38,130 @@ from nitransforms.resampling import apply from typing_extensions import Self -from nifreeze.data.base import BaseDataset, _cmp, _data_repr +from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim from nifreeze.utils.ndimage import load_api +ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None" +"""PET initialization array attribute absence error message.""" + +ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array." +"""PET initialization array attribute object error message.""" + +ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array." +"""PET initialization array attribute ndim error message.""" + +SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a scalar." +"""PET initialization scalar attribute object error message.""" + +ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = ( + "PET '{attribute}' length ({attr_len}) does not match number of frames ({data_frames})" +) +"""PET attribute shape mismatch error message.""" + + +def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None: + """Strict validator to ensure an attribute is a 1D NumPy array. + + Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly + one dimension (``value.ndim == 1``). + + This function is intended for use as an attrs-style validator. + + Parameters + ---------- + inst : :obj:`~nifreeze.data.pet.PET` + The instance being validated (unused; present for validator signature). + attr : :obj:`~attrs.Attribute` + The attribute being validated; ``attr.name`` is used in the error message. + value : :obj:`Any` + The value to validate. + + Raises + ------ + :exc:`TypeError` + If the value is not a :obj:`~numpy.ndarray`. + :exc:`ValueError` + If the value is ``None``, or not 1D. 
+ """ + + if value is None: + raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name)) + + if not isinstance(value, np.ndarray): + raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) + + if not _has_ndim(value, 1): + raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name)) + + +def validate_scalar(inst: PET, attr: attrs.Attribute, value: Any) -> None: + """Strict validator to ensure an attribute is a scalar number. + + Ensures that ``value`` is a Python integer or floating point number, or a + NumPy scalar numeric type (e.g., :obj:`numpy.integer`, :obj:`numpy.floating`). + + This function is intended for use as an attrs-style validator. + + Parameters + ---------- + inst : :obj:`~nifreeze.data.pet.PET` + The instance being validated (unused; present for validator signature). + attr : :obj:`~attrs.Attribute` + The attribute being validated; ``attr.name`` is used in the error message. + value : :obj:`Any` + The value to validate. + + Raises + ------ + :exc:`ValueError` + If ``value`` is not an int/float or a NumPy numeric scalar type. 
+ """ + if not isinstance(value, (int, float, np.integer, np.floating)): + raise ValueError(SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) + @attrs.define(slots=True) class PET(BaseDataset[np.ndarray]): """Data representation structure for PET data.""" - midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) + midframe: np.ndarray = attrs.field( + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array + ) """A (N,) numpy array specifying the midpoint timing of each sample or frame.""" - total_duration: float = attrs.field(default=None, repr=True) + total_duration: float = attrs.field(default=None, repr=True, validator=validate_scalar) """A float representing the total duration of the dataset.""" - uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) + uptake: np.ndarray = attrs.field( + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array + ) """A (N,) numpy array specifying the uptake value of each sample or frame.""" + def __attrs_post_init__(self) -> None: + """Enforce presence and basic consistency of required PET fields at + instantiation time. + + Specifically, the length of the midframe and uptake attributes must + match the last dimension of the data (number of frames). 
+ """ + data_frames = int(self.dataobj.shape[-1]) + + if len(self.midframe) != data_frames: + raise ValueError( + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format( + attribute=attrs.fields_dict(self.__class__)["midframe"].name, + attr_len=len(self.midframe), + data_frames=data_frames, + ) + ) + + if len(self.uptake) != data_frames: + raise ValueError( + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format( + attribute=attrs.fields_dict(self.__class__)["uptake"].name, + attr_len=len(self.uptake), + data_frames=data_frames, + ) + ) + def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]: return (self.midframe[idx],) @@ -258,37 +367,46 @@ def from_nii( raise NotImplementedError filename = Path(filename) - # Load from NIfTI + + # 1) Load a NIfTI img = load_api(filename, SpatialImage) - data = img.get_fdata(dtype=np.float32) - pet_obj = PET( - dataobj=data, - affine=img.affine, - ) + fulldata = img.get_fdata(dtype=np.float32) + + # 2) Determine uptake value + uptake = _compute_uptake_statistic(fulldata, stat_func=np.sum) - pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum) + # 3) Compute temporal features # Convert to a float32 numpy array and zero out the earliest time frame_time_arr = np.array(frame_time, dtype=np.float32) frame_time_arr -= frame_time_arr[0] - pet_obj.midframe = frame_time_arr + midframe = frame_time_arr # If the user doesn't provide frame_duration, we derive it: if frame_duration is None: - durations = _compute_frame_duration(pet_obj.midframe) + durations = _compute_frame_duration(midframe) else: durations = np.array(frame_duration, dtype=np.float32) - # Set total_duration and shift frame_time to the midpoint - pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1]) - pet_obj.midframe = frame_time_arr + 0.5 * durations + # Compute total_duration and shift midframe to the midpoint + total_duration = float(frame_time_arr[-1] + durations[-1]) + midframe = frame_time_arr + 0.5 * durations - # If a brain mask is 
provided, load and attach + # 4) If a brainmask_file was provided, load it + brainmask_data = None if brainmask_file is not None: mask_img = load_api(brainmask_file, SpatialImage) - pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool) + brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool) - return pet_obj + # 5) Create and return the PET instance. + return PET( + dataobj=fulldata, + affine=img.affine, + brainmask=brainmask_data, + midframe=midframe, + total_duration=total_duration, + uptake=uptake, + ) def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray: diff --git a/test/test_data_pet.py b/test/test_data_pet.py index f03861a29..ce6ad2639 100644 --- a/test/test_data_pet.py +++ b/test/test_data_pet.py @@ -22,6 +22,7 @@ # import json +import re from pathlib import Path import nibabel as nb @@ -29,10 +30,37 @@ import pytest from nitransforms.linear import Affine -from nifreeze.data.pet import PET, _compute_frame_duration, _compute_uptake_statistic, from_nii +from nifreeze.data.pet import ( + ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG, + ARRAY_ATTRIBUTE_NDIM_ERROR_MSG, + ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG, + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG, + PET, + SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG, + _compute_frame_duration, + _compute_uptake_statistic, + from_nii, +) from nifreeze.utils.ndimage import load_api +def _pet_data_to_nifti(pet_dataobj, affine, brainmask_dataobj): + pet = nb.Nifti1Image(pet_dataobj, affine) + brainmask = nb.Nifti1Image(brainmask_dataobj, affine) + + return pet, brainmask + + +def _serialize_pet_data(pet, brainmask, _tmp_path): + pet_fname = _tmp_path / "pet.nii.gz" + brainmask_fname = _tmp_path / "brainmask.nii.gz" + + nb.save(pet, pet_fname) + nb.save(brainmask, brainmask_fname) + + return pet_fname, brainmask_fname + + @pytest.fixture def random_dataset(setup_random_pet_data) -> PET: """Create a PET dataset with random data for testing.""" @@ -63,6 +91,114 @@ def random_nifti_file(tmp_path, setup_random_uniform_spatial_data) -> 
Path: return _filename +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize( + "midframe, expected_exc, expected_msg", + [ + (None, ValueError, ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG), + (1, TypeError, ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG), + ], +) +def test_pet_attribute_basic_errors( + setup_random_uniform_spatial_data, midframe, expected_exc, expected_msg +): + data, affine = setup_random_uniform_spatial_data + with pytest.raises(expected_exc, match=expected_msg.format(attribute="midframe")): + PET(dataobj=data, affine=affine, midframe=midframe) # type: ignore[arg-type] + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("size", ((2, 1), (1, 2), (3, 1), (3, 2))) +def test_pet_midframe_shape_error(setup_random_uniform_spatial_data, size): + data, affine = setup_random_uniform_spatial_data + midframe = np.zeros(size, dtype=np.float32) + with pytest.raises( + ValueError, match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="midframe") + ): + PET(dataobj=data, affine=affine, midframe=midframe) + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("size", ((2, 1), (1, 1))) +def test_pet_total_duration_error(request, setup_random_uniform_spatial_data, size): + data, affine = setup_random_uniform_spatial_data + midframe = np.zeros(data.shape[-1], dtype=np.float32) + total_duration = request.node.rng.uniform(5.0, 20.0, size=size) + with pytest.raises( + ValueError, match=SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute="total_duration") + ): + PET(dataobj=data, affine=affine, midframe=midframe, total_duration=total_duration) + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +@pytest.mark.parametrize("size", ((2, 1), (1, 2), (3, 1), (3, 2))) +def test_pet_uptake_shape_error(setup_random_uniform_spatial_data, size): + data, affine = setup_random_uniform_spatial_data + midframe = np.zeros(data.shape[-1], dtype=np.float32) + 
total_duration = 16.2 + uptake = np.zeros(size, dtype=np.float32) + with pytest.raises( + ValueError, match=ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute="uptake") + ): + PET( + dataobj=data, + affine=affine, + midframe=midframe, + total_duration=total_duration, + uptake=uptake, + ) + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +def test_pet_midframe_length_mismatch(setup_random_uniform_spatial_data): + data, affine = setup_random_uniform_spatial_data + total_duration = 16.2 + data_frames = data.shape[-1] + attr_len = data_frames + 1 + midframe = np.zeros(attr_len, dtype=np.float32) + uptake = np.zeros(data_frames, dtype=np.float32) + with pytest.raises( + ValueError, + match=re.escape( + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format( + attribute="midframe", attr_len=attr_len, data_frames=data_frames + ) + ), + ): + PET( + dataobj=data, + affine=affine, + midframe=midframe, + total_duration=total_duration, + uptake=uptake, + ) + + +@pytest.mark.random_uniform_spatial_data((2, 2, 2, 4), 0.0, 1.0) +def test_pet_uptake_length_mismatch(setup_random_uniform_spatial_data): + data, affine = setup_random_uniform_spatial_data + total_duration = 16.2 + data_frames = data.shape[-1] + midframe = np.zeros(data_frames, dtype=np.float32) + attr_len = data_frames + 1 + uptake = np.zeros(attr_len, dtype=np.float32) + with pytest.raises( + ValueError, + match=re.escape( + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format( + attribute="uptake", attr_len=attr_len, data_frames=data_frames + ) + ), + ): + PET( + dataobj=data, + affine=affine, + midframe=midframe, + total_duration=total_duration, + uptake=uptake, + ) + + @pytest.mark.parametrize( "midframe, expected", [ @@ -175,6 +311,84 @@ def test_round_trip(tmp_path, random_nifti_file, frame_time, frame_duration): assert units[0] == "mm" +def test_equality_operator(tmp_path, setup_random_pet_data): + ( + pet_dataobj, + affine, + brainmask_dataobj, + midframe, + total_duration, + ) = setup_random_pet_data + + pet, 
brainmask = _pet_data_to_nifti( + pet_dataobj, + affine, + brainmask_dataobj.astype(np.uint8), + ) + + ( + pet_fname, + brainmask_fname, + ) = _serialize_pet_data( + pet, + brainmask, + tmp_path, + ) + + # ToDo + # All this needs to be done automatically: the from_nii and PET data + # instantiation need to be refactored + + start_time = 0.0 + mid = np.asarray(midframe, dtype=np.float32) + frame_time = (mid + np.float32(start_time)).astype(np.float32) + frame_duration = _compute_frame_duration(midframe) + + # Read back using public API + pet_obj_from_nii = from_nii( + pet_fname, + frame_time=frame_time, + brainmask_file=brainmask_fname, + frame_duration=frame_duration, + ) + + # Direct instantiation with the same arrays + uptake = _compute_uptake_statistic(pet_dataobj, stat_func=np.sum) + pet_obj_direct = PET( + dataobj=pet_dataobj, + affine=affine, + brainmask=brainmask_dataobj, + midframe=midframe, + total_duration=total_duration, + uptake=uptake, + ) + + # Sanity checks (element-wise) + assert np.allclose(pet_obj_direct.dataobj, pet_obj_from_nii.dataobj) + assert np.allclose(pet_obj_direct.affine, pet_obj_from_nii.affine) + if pet_obj_direct.brainmask is None or pet_obj_from_nii.brainmask is None: + assert pet_obj_direct.brainmask is None + assert pet_obj_from_nii.brainmask is None + else: + assert np.array_equal(pet_obj_direct.brainmask, pet_obj_from_nii.brainmask) + assert np.allclose(pet_obj_direct.midframe, pet_obj_from_nii.midframe) + assert np.allclose(pet_obj_direct.total_duration, pet_obj_from_nii.total_duration) + assert np.allclose(pet_obj_direct.uptake, pet_obj_from_nii.uptake) + + # Test equality operator + assert pet_obj_direct == pet_obj_from_nii + + # Test equality operator against an instance from HDF5 + hdf5_filename = tmp_path / "test_pet.h5" + pet_obj_from_nii.to_filename(hdf5_filename) + + round_trip_pet_obj = PET.from_filename(hdf5_filename) + + # Symmetric equality + assert pet_obj_from_nii == round_trip_pet_obj + assert round_trip_pet_obj 
== pet_obj_from_nii + + @pytest.mark.random_pet_data(5, (4, 4, 4), np.asarray([10.0, 20.0, 30.0, 40.0, 50.0]), 60.0) def test_pet_set_transform_updates_motion_affines(random_dataset): idx = 2