Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 137 additions & 19 deletions src/nifreeze/data/pet.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,130 @@
from nitransforms.resampling import apply
from typing_extensions import Self

from nifreeze.data.base import BaseDataset, _cmp, _data_repr
from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim
from nifreeze.utils.ndimage import load_api

ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
"""PET initialization array attribute absence error message."""

ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
"""PET initialization array attribute object error message."""

ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
"""PET initialization array attribute ndim error message."""

SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a scalar."
"""PET initialization scalar attribute shape error message."""

ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = (
"PET '{attribute}' length ({attr_len}) does not match number of frames ({data_frames})"
)
"""PET attribute shape mismatch error message."""


def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None:
"""Strict validator to ensure an attribute is a 1D NumPy array.

Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly
one dimension (``value.ndim == 1``).

This function is intended for use as an attrs-style validator.

Parameters
----------
inst : :obj:`~nifreeze.data.pet.PET`
The instance being validated (unused; present for validator signature).
attr : :obj:`~attrs.Attribute`
The attribute being validated; ``attr.name`` is used in the error message.
value : :obj:`Any`
The value to validate.

Raises
------
exc:`TypeError`
If the input cannot be converted to a float :obj:`~numpy.ndarray`.
exc:`ValueError`
If the value is ``None``, or not 1D.
"""

if value is None:
raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name))

if not isinstance(value, np.ndarray):
raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))

if not _has_ndim(value, 1):
raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name))


def validate_scalar(inst: PET, attr: attrs.Attribute, value: Any) -> None:
"""Strict validator to ensure an attribute is a scalar number.

Ensures that ``value`` is a Python integer or floating point number, or a
NumPy scalar numeric type (e.g., :obj:`numpy.integer`, :obj:`numpy.floating`).

This function is intended for use as an attrs-style validator.

Parameters
----------
inst : :obj:`~nifreeze.data.pet.PET`
The instance being validated (unused; present for validator signature).
attr : :obj:`~attrs.Attribute`
The attribute being validated; attr.name is used in the error message.
value : :obj:`Any`
The value to validate.

Raises
------
exc:`ValueError`
If ``value`` is not an int/float or a NumPy numeric scalar type.
"""
if not isinstance(value, (int, float, np.integer, np.floating)):
raise ValueError(SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name))


@attrs.define(slots=True)
class PET(BaseDataset[np.ndarray]):
"""Data representation structure for PET data."""

midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
midframe: np.ndarray = attrs.field(
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array
)
"""A (N,) numpy array specifying the midpoint timing of each sample or frame."""
total_duration: float = attrs.field(default=None, repr=True)
total_duration: float = attrs.field(default=None, repr=True, validator=validate_scalar)
"""A float representing the total duration of the dataset."""
uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp))
uptake: np.ndarray = attrs.field(
default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array
)
"""A (N,) numpy array specifying the uptake value of each sample or frame."""

def __attrs_post_init__(self) -> None:
"""Enforce presence and basic consistency of required PET fields at
instantiation time.

Specifically, the length of the midframe and uptake attributes must
match the last dimension of the data (number of frames).
"""
data_frames = int(self.dataobj.shape[-1])

if len(self.midframe) != data_frames:
raise ValueError(
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
attribute=attrs.fields_dict(self.__class__)["midframe"].name,
attr_len=len(self.midframe),
data_frames=data_frames,
)
)

if len(self.uptake) != data_frames:
raise ValueError(
ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format(
attribute=attrs.fields_dict(self.__class__)["uptake"].name,
attr_len=len(self.uptake),
data_frames=data_frames,
)
)

def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]:
return (self.midframe[idx],)

Expand Down Expand Up @@ -258,37 +367,46 @@ def from_nii(
raise NotImplementedError

filename = Path(filename)
# Load from NIfTI

# 1) Load a NIfTI
img = load_api(filename, SpatialImage)
data = img.get_fdata(dtype=np.float32)
pet_obj = PET(
dataobj=data,
affine=img.affine,
)
fulldata = img.get_fdata(dtype=np.float32)

# 2) Determine uptake value
uptake = _compute_uptake_statistic(fulldata, stat_func=np.sum)

pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum)
# 3) Compute temporal features

# Convert to a float32 numpy array and zero out the earliest time
frame_time_arr = np.array(frame_time, dtype=np.float32)
frame_time_arr -= frame_time_arr[0]
pet_obj.midframe = frame_time_arr
midframe = frame_time_arr

# If the user doesn't provide frame_duration, we derive it:
if frame_duration is None:
durations = _compute_frame_duration(pet_obj.midframe)
durations = _compute_frame_duration(midframe)
else:
durations = np.array(frame_duration, dtype=np.float32)

# Set total_duration and shift frame_time to the midpoint
pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1])
pet_obj.midframe = frame_time_arr + 0.5 * durations
# Compute total_duration and shift midframe to the midpoint
total_duration = float(frame_time_arr[-1] + durations[-1])
midframe = frame_time_arr + 0.5 * durations

# If a brain mask is provided, load and attach
# 4) If a brainmask_file was provided, load it
brainmask_data = None
if brainmask_file is not None:
mask_img = load_api(brainmask_file, SpatialImage)
pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool)

return pet_obj
# 5) Create and return the DWI instance.
return PET(
dataobj=fulldata,
affine=img.affine,
brainmask=brainmask_data,
midframe=midframe,
total_duration=total_duration,
uptake=uptake,
)


def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
Expand Down
Loading
Loading