3838from nitransforms .resampling import apply
3939from typing_extensions import Self
4040
41- from nifreeze .data .base import BaseDataset , _cmp , _data_repr
41+ from nifreeze .data .base import BaseDataset , _cmp , _data_repr , _has_ndim
4242from nifreeze .utils .ndimage import load_api
4343
44+ ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None"
45+ """PET initialization array attribute absence error message."""
46+
47+ ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array."
48+ """PET initialization array attribute object error message."""
49+
50+ ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array."
51+ """PET initialization array attribute ndim error message."""
52+
53+ SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a scalar."
54+ """PET initialization scalar attribute shape error message."""
55+
56+ ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = (
57+ "PET '{attribute}' length ({attr_len}) does not match number of frames ({data_frames})"
58+ )
59+ """PET attribute shape mismatch error message."""
60+
61+
62+ def validate_1d_array (inst : PET , attr : attrs .Attribute , value : Any ) -> None :
63+ """Strict validator to ensure an attribute is a 1D NumPy array.
64+
65+ Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly
66+ one dimension (``value.ndim == 1``).
67+
68+ This function is intended for use as an attrs-style validator.
69+
70+ Parameters
71+ ----------
72+ inst : :obj:`~nifreeze.data.pet.PET`
73+ The instance being validated (unused; present for validator signature).
74+ attr : :obj:`~attrs.Attribute`
75+ The attribute being validated; ``attr.name`` is used in the error message.
76+ value : :obj:`Any`
77+ The value to validate.
78+
79+ Raises
80+ ------
81+ exc:`TypeError`
82+ If the input cannot be converted to a float :obj:`~numpy.ndarray`.
83+ exc:`ValueError`
84+ If the value is ``None``, or not 1D.
85+ """
86+
87+ if value is None :
88+ raise ValueError (ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG .format (attribute = attr .name ))
89+
90+ if not isinstance (value , np .ndarray ):
91+ raise TypeError (ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG .format (attribute = attr .name ))
92+
93+ if not _has_ndim (value , 1 ):
94+ raise ValueError (ARRAY_ATTRIBUTE_NDIM_ERROR_MSG .format (attribute = attr .name ))
95+
96+
97+ def validate_scalar (inst : PET , attr : attrs .Attribute , value : Any ) -> None :
98+ """Strict validator to ensure an attribute is a scalar number.
99+
100+ Ensures that ``value`` is a Python integer or floating point number, or a
101+ NumPy scalar numeric type (e.g., :obj:`numpy.integer`, :obj:`numpy.floating`).
102+
103+ This function is intended for use as an attrs-style validator.
104+
105+ Parameters
106+ ----------
107+ inst : :obj:`~nifreeze.data.pet.PET`
108+ The instance being validated (unused; present for validator signature).
109+ attr : :obj:`~attrs.Attribute`
110+ The attribute being validated; attr.name is used in the error message.
111+ value : :obj:`Any`
112+ The value to validate.
113+
114+ Raises
115+ ------
116+ exc:`ValueError`
117+ If ``value`` is not an int/float or a NumPy numeric scalar type.
118+ """
119+ if not isinstance (value , (int , float , np .integer , np .floating )):
120+ raise ValueError (SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG .format (attribute = attr .name ))
121+
44122
45123@attrs .define (slots = True )
46124class PET (BaseDataset [np .ndarray ]):
47- """Data representation structure for PET data."""
125+ """Data representation structure for PET data.
126+
127+ If not provided, frame duration data are computed as differences between
128+ consecutive midframe times. The last interval is duplicated.
129+ """
48130
49- midframe : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
131+ frame_time : np .ndarray = attrs .field (
132+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ), validator = validate_1d_array
133+ )
134+ """A (N,) numpy array specifying the timing of each sample or frame."""
135+ uptake : np .ndarray = attrs .field (
136+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ), validator = validate_1d_array
137+ )
138+ """A (N,) numpy array specifying the uptake value of each sample or frame."""
139+ frame_duration : np .ndarray | None = attrs .field (
140+ default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp )
141+ )
142+ """A (N,) numpy array specifying the frame duration."""
143+ midframe : np .ndarray = attrs .field (
144+ default = None , repr = _data_repr , init = False , eq = attrs .cmp_using (eq = _cmp )
145+ )
50146 """A (N,) numpy array specifying the midpoint timing of each sample or frame."""
51- total_duration : float = attrs .field (default = None , repr = True )
147+ total_duration : float = attrs .field (default = None , repr = True , init = False )
52148 """A float representing the total duration of the dataset."""
53- uptake : np .ndarray = attrs .field (default = None , repr = _data_repr , eq = attrs .cmp_using (eq = _cmp ))
54- """A (N,) numpy array specifying the uptake value of each sample or frame."""
149+
150+ def __attrs_post_init__ (self ) -> None :
151+ """Enforce presence and basic consistency of PET data fields at
152+ instantiation time.
153+
154+ Specifically, the length of the frame_time and uptake attributes must
155+ match the last dimension of the data (number of frames).
156+
157+ Computes the values for the private attributes.
158+ """
159+ data_frames = int (self .dataobj .shape [- 1 ])
160+
161+ if len (self .frame_time ) != data_frames :
162+ raise ValueError (
163+ ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG .format (
164+ attribute = attrs .fields_dict (self .__class__ )["frame_time" ].name ,
165+ attr_len = len (self .frame_time ),
166+ data_frames = data_frames ,
167+ )
168+ )
169+
170+ if len (self .uptake ) != data_frames :
171+ raise ValueError (
172+ ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG .format (
173+ attribute = attrs .fields_dict (self .__class__ )["uptake" ].name ,
174+ attr_len = len (self .uptake ),
175+ data_frames = data_frames ,
176+ )
177+ )
178+
179+ # Compute temporal attributes
180+
181+ # Convert to a float32 numpy array and zero out the earliest time
182+ frame_time_arr = np .array (self .frame_time , dtype = np .float32 )
183+ frame_time_arr -= frame_time_arr [0 ]
184+ self .midframe = frame_time_arr
185+
186+ # If the user did not provide frame duration values,compute them
187+ if self .frame_duration :
188+ durations = np .array (self .frame_duration , dtype = np .float32 )
189+ else :
190+ durations = _compute_frame_duration (self .midframe )
191+
192+ # Compute total duration and shift midframe to the midpoint
193+ self .total_duration = float (self .midframe [- 1 ] + durations [- 1 ])
194+ self .midframe = self .midframe + 0.5 * durations
55195
56196 def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [np .ndarray ]:
57197 return (self .midframe [idx ],)
@@ -223,6 +363,7 @@ def from_nii(
223363 brainmask_file : Path | str | None = None ,
224364 motion_file : Path | str | None = None ,
225365 frame_duration : np .ndarray | list [float ] | None = None ,
366+ uptake_stat_func : Callable [..., np .ndarray ] | None = np .sum ,
226367) -> PET :
227368 """
228369 Load PET data from NIfTI, creating a PET object with appropriate metadata.
@@ -242,6 +383,8 @@ def from_nii(
242383 The duration of each frame.
243384 If ``None``, it is derived by the difference of consecutive frame times,
244385 defaulting the last frame to match the second-last.
386+ uptake_stat_func : :obj:`Callable`, optional
387+ The statistic function to be used to compute the uptake value.
245388
246389 Returns
247390 -------
@@ -258,37 +401,29 @@ def from_nii(
258401 raise NotImplementedError
259402
260403 filename = Path (filename )
261- # Load from NIfTI
262- img = load_api (filename , SpatialImage )
263- data = img .get_fdata (dtype = np .float32 )
264- pet_obj = PET (
265- dataobj = data ,
266- affine = img .affine ,
267- )
268-
269- pet_obj .uptake = _compute_uptake_statistic (data , stat_func = np .sum )
270-
271- # Convert to a float32 numpy array and zero out the earliest time
272- frame_time_arr = np .array (frame_time , dtype = np .float32 )
273- frame_time_arr -= frame_time_arr [0 ]
274- pet_obj .midframe = frame_time_arr
275404
276- # If the user doesn't provide frame_duration, we derive it:
277- if frame_duration is None :
278- durations = _compute_frame_duration (pet_obj .midframe )
279- else :
280- durations = np .array (frame_duration , dtype = np .float32 )
405+ # 1) Load a NIfTI
406+ img = load_api (filename , SpatialImage )
407+ fulldata = img .get_fdata (dtype = np .float32 )
281408
282- # Set total_duration and shift frame_time to the midpoint
283- pet_obj .total_duration = float (frame_time_arr [- 1 ] + durations [- 1 ])
284- pet_obj .midframe = frame_time_arr + 0.5 * durations
409+ # 2) Determine uptake value
410+ uptake = _compute_uptake_statistic (fulldata , stat_func = uptake_stat_func )
285411
286- # If a brain mask is provided, load and attach
412+ # 3) If a brainmask_file was provided, load it
413+ brainmask_data = None
287414 if brainmask_file is not None :
288415 mask_img = load_api (brainmask_file , SpatialImage )
289- pet_obj . brainmask = np .asanyarray (mask_img .dataobj , dtype = bool )
416+ brainmask_data = np .asanyarray (mask_img .dataobj , dtype = bool )
290417
291- return pet_obj
418+ # 4) Create and return the PET instance
419+ return PET (
420+ dataobj = fulldata ,
421+ affine = img .affine ,
422+ brainmask = brainmask_data ,
423+ frame_time = frame_time ,
424+ frame_duration = frame_duration ,
425+ uptake = uptake ,
426+ )
292427
293428
294429def _compute_frame_duration (midframe : np .ndarray ) -> np .ndarray :
0 commit comments