3838from nitransforms .resampling import apply
3939from typing_extensions import Self
4040
41- from nifreeze .data .base import BaseDataset , _cmp , _data_repr
41+ from nifreeze .data .base import BaseDataset , _cmp , _data_repr , _has_ndim
4242from nifreeze .utils .ndimage import load_api
4343
44+ ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+ """PET initialization array attribute absence error message."""
46+
47+ ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
48+ """PET initialization array attribute object error message."""
49+
50+ ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
51+ """PET initialization array attribute ndim error message."""
52+
53+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR = """\
54+ PET '{attribute}' length does not match number of frames: \
55+ expected {n_frames} values, found {attr_len}."""
56+ """PET attribute shape mismatch error message."""
57+
58+
59+ def validate_1d_array (inst : PET , attr : attrs .Attribute , value : Any ) -> None :
60+ """Strict validator to ensure an attribute is a 1D NumPy array.
61+
62+ Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly
63+ one dimension (``value.ndim == 1``).
64+
65+ This function is intended for use as an attrs-style validator.
66+
67+ Parameters
68+ ----------
69+ inst : :obj:`~nifreeze.data.pet.PET`
70+ The instance being validated (unused; present for validator signature).
71+ attr : :obj:`~attrs.Attribute`
72+ The attribute being validated; ``attr.name`` is used in the error message.
73+ value : :obj:`Any`
74+ The value to validate.
75+
76+ Raises
77+ ------
78+ exc:`TypeError`
79+ If the input cannot be converted to a float :obj:`~numpy.ndarray`.
80+ exc:`ValueError`
81+ If the value is ``None``, or not 1D.
82+ """
83+
84+ if value is None :
85+ raise ValueError (ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG .format (attribute = attr .name ))
86+
87+ if not isinstance (value , np .ndarray ):
88+ raise TypeError (ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG .format (attribute = attr .name ))
89+
90+ if not _has_ndim (value , 1 ):
91+ raise ValueError (ARRAY_ATTRIBUTE_NDIM_ERROR_MSG .format (attribute = attr .name ))
92+
4493
4594@attrs .define (slots = True )
4695class PET (BaseDataset [np .ndarray ]):
47- """Data representation structure for PET data."""
96+ """Data representation structure for PET data.
97+
98+ If not provided, frame duration data are computed as differences between
99+ consecutive midframe times. The last interval is duplicated.
100+ """
48101
49- midframe : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
102+ frame_time : np .ndarray = attrs .field (
103+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ), validator = validate_1d_array
104+ )
105+ """A (N,) numpy array specifying the timing of each sample or frame."""
106+ uptake : np .ndarray = attrs .field (
107+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ), validator = validate_1d_array
108+ )
109+ """A (N,) numpy array specifying the uptake value of each sample or frame."""
110+ frame_duration : np .ndarray | None = attrs .field (
111+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
112+ )
113+ """A (N,) numpy array specifying the frame duration."""
114+ midframe : np .ndarray = attrs .field (
115+ default = None , repr = _data_repr , init = False , eq = attrs .cmp_using (eq = _cmp )
116+ )
50117 """A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51- total_duration : float = attrs .field (default = None , repr = True )
118+ total_duration : float = attrs .field (default = None , repr = True , init = False )
52119 """A float representing the total duration of the dataset."""
53- uptake : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
54- """A (N,) numpy array specifying the uptake value of each sample or frame."""
120+
121+ def __attrs_post_init__ (self ) -> None :
122+ """Enforce presence and basic consistency of PET data fields at
123+ instantiation time.
124+
125+ Specifically, the length of the frame_time and uptake attributes must
126+ match the last dimension of the data (number of frames).
127+
128+ Computes the values for the private attributes.
129+ """
130+ n_frames = int (self .dataobj .shape [- 1 ])
131+
132+ if len (self .frame_time ) != n_frames :
133+ raise ValueError (
134+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
135+ attribute = attrs .fields_dict (self .__class__ )["frame_time" ].name ,
136+ n_frames = n_frames ,
137+ attr_len = len (self .frame_time ),
138+ )
139+ )
140+
141+ if len (self .uptake ) != n_frames :
142+ raise ValueError (
143+ ATTRIBUTE_VOLUME_DIMENSIONALITY_MISMATCH_ERROR .format (
144+ attribute = attrs .fields_dict (self .__class__ )["uptake" ].name ,
145+ n_frames = n_frames ,
146+ attr_len = len (self .uptake ),
147+ )
148+ )
149+
150+ # Compute temporal attributes
151+
152+ # Convert to a float32 numpy array and zero out the earliest time
153+ frame_time_arr = np .array (self .frame_time , dtype = np .float32 )
154+ frame_time_arr -= frame_time_arr [0 ]
155+ self .midframe = frame_time_arr
156+
157+ # If the user did not provide frame duration values,compute them
158+ if self .frame_duration :
159+ durations = np .array (self .frame_duration , dtype = np .float32 )
160+ else :
161+ durations = _compute_frame_duration (self .midframe )
162+
163+ # Compute total duration and shift midframe to the midpoint
164+ self .total_duration = float (self .midframe [- 1 ] + durations [- 1 ])
165+ self .midframe = self .midframe + 0.5 * durations
55166
56167 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
57168 return (self .midframe [idx ],)
@@ -223,6 +334,7 @@ def from_nii(
223334 brainmask_file : Path | str | None = None ,
224335 motion_file : Path | str | None = None ,
225336 frame_duration : np .ndarray | list [float ] | None = None ,
337+ uptake_stat_func : Callable [..., np .ndarray ] = np .sum ,
226338) -> PET :
227339 """
228340 Load PET data from NIfTI, creating a PET object with appropriate metadata.
@@ -242,6 +354,8 @@ def from_nii(
242354 The duration of each frame.
243355 If ``None``, it is derived by the difference of consecutive frame times,
244356 defaulting the last frame to match the second-last.
357+ uptake_stat_func : :obj:`Callable`, optional
358+ The statistic function to be used to compute the uptake value.
245359
246360 Returns
247361 -------
@@ -258,37 +372,29 @@ def from_nii(
258372 raise NotImplementedError
259373
260374 filename = Path (filename )
261- # Load from NIfTI
262- img = load_api (filename , SpatialImage )
263- data = img .get_fdata (dtype = np .float32 )
264- pet_obj = PET (
265- dataobj = data ,
266- affine = img .affine ,
267- )
268-
269- pet_obj .uptake = _compute_uptake_statistic (data , stat_func = np .sum )
270375
271- # Convert to a float32 numpy array and zero out the earliest time
272- frame_time_arr = np .array (frame_time , dtype = np .float32 )
273- frame_time_arr -= frame_time_arr [0 ]
274- pet_obj .midframe = frame_time_arr
275-
276- # If the user doesn't provide frame_duration, we derive it:
277- if frame_duration is None :
278- durations = _compute_frame_duration (pet_obj .midframe )
279- else :
280- durations = np .array (frame_duration , dtype = np .float32 )
376+ # 1) Load a NIfTI
377+ img = load_api (filename , SpatialImage )
378+ fulldata = img .get_fdata (dtype = np .float32 )
281379
282- # Set total_duration and shift frame_time to the midpoint
283- pet_obj .total_duration = float (frame_time_arr [- 1 ] + durations [- 1 ])
284- pet_obj .midframe = frame_time_arr + 0.5 * durations
380+ # 2) Determine uptake value
381+ uptake = _compute_uptake_statistic (fulldata , stat_func = uptake_stat_func )
285382
286- # If a brain mask is provided, load and attach
383+ # 3) If a brainmask_file was provided, load it
384+ brainmask_data = None
287385 if brainmask_file is not None :
288386 mask_img = load_api (brainmask_file , SpatialImage )
289- pet_obj . brainmask = np .asanyarray (mask_img .dataobj , dtype = bool )
387+ brainmask_data = np .asanyarray (mask_img .dataobj , dtype = bool )
290388
291- return pet_obj
389+ # 4) Create and return the PET instance
390+ return PET (
391+ dataobj = fulldata ,
392+ affine = img .affine ,
393+ brainmask = brainmask_data ,
394+ frame_time = np .asarray (frame_time ),
395+ frame_duration = np .asarray (frame_duration ),
396+ uptake = uptake ,
397+ )
292398
293399
294400def _compute_frame_duration (midframe : np .ndarray ) -> np .ndarray :
@@ -313,7 +419,7 @@ def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray:
313419 return durations
314420
315421
316- def _compute_uptake_statistic (data : np .ndarray , stat_func : Callable = np .sum ):
422+ def _compute_uptake_statistic (data : np .ndarray , stat_func : Callable [..., np . ndarray ] = np .sum ):
317423 """Compute a statistic over all voxels for each frame on a PET sequence.
318424
319425 Assumes the last dimension corresponds to the number of frames in the
0 commit comments