|
38 | 38 | from nitransforms.resampling import apply |
39 | 39 | from typing_extensions import Self |
40 | 40 |
|
41 | | -from nifreeze.data.base import BaseDataset, _cmp, _data_repr |
| 41 | +from nifreeze.data.base import BaseDataset, _cmp, _data_repr, _has_ndim |
42 | 42 | from nifreeze.utils.ndimage import load_api |
43 | 43 |
|
| 44 | +ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG = "PET '{attribute}' may not be None" |
| 45 | +"""PET initialization array attribute absence error message.""" |
| 46 | + |
| 47 | +ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a numpy array." |
| 48 | +"""PET initialization array attribute object error message.""" |
| 49 | + |
| 50 | +ARRAY_ATTRIBUTE_NDIM_ERROR_MSG = "PET '{attribute}' must be a 1D numpy array." |
| 51 | +"""PET initialization array attribute ndim error message.""" |
| 52 | + |
| 53 | +SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG = "PET '{attribute}' must be a scalar." |
| 54 | +"""PET initialization scalar attribute shape error message.""" |
| 55 | + |
| 56 | +ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG = ( |
| 57 | + "PET '{attribute}' length ({attr_len}) does not match number of frames ({data_frames})" |
| 58 | +) |
| 59 | +"""PET attribute shape mismatch error message.""" |
| 60 | + |
| 61 | + |
| 62 | +def validate_1d_array(inst: PET, attr: attrs.Attribute, value: Any) -> None: |
| 63 | + """Strict validator to ensure an attribute is a 1D NumPy array. |
| 64 | +
|
| 65 | + Enforces that ``value`` is a :obj:`~numpy.ndarray` and that it has exactly |
| 66 | + one dimension (``value.ndim == 1``). |
| 67 | +
|
| 68 | + This function is intended for use as an attrs-style validator. |
| 69 | +
|
| 70 | + Parameters |
| 71 | + ---------- |
| 72 | + inst : :obj:`~nifreeze.data.pet.PET` |
| 73 | + The instance being validated (unused; present for validator signature). |
| 74 | + attr : :obj:`~attrs.Attribute` |
| 75 | + The attribute being validated; ``attr.name`` is used in the error message. |
| 76 | + value : :obj:`Any` |
| 77 | + The value to validate. |
| 78 | +
|
| 79 | + Raises |
| 80 | + ------ |
| 81 | + exc:`TypeError` |
| 82 | + If the input cannot be converted to a float :obj:`~numpy.ndarray`. |
| 83 | + exc:`ValueError` |
| 84 | + If the value is ``None``, or not 1D. |
| 85 | + """ |
| 86 | + |
| 87 | + if value is None: |
| 88 | + raise ValueError(ARRAY_ATTRIBUTE_ABSENCE_ERROR_MSG.format(attribute=attr.name)) |
| 89 | + |
| 90 | + if not isinstance(value, np.ndarray): |
| 91 | + raise TypeError(ARRAY_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) |
| 92 | + |
| 93 | + if not _has_ndim(value, 1): |
| 94 | + raise ValueError(ARRAY_ATTRIBUTE_NDIM_ERROR_MSG.format(attribute=attr.name)) |
| 95 | + |
| 96 | + |
| 97 | +def validate_scalar(inst: PET, attr: attrs.Attribute, value: Any) -> None: |
| 98 | + """Strict validator to ensure an attribute is a scalar number. |
| 99 | +
|
| 100 | + Ensures that ``value`` is a Python integer or floating point number, or a |
| 101 | + NumPy scalar numeric type (e.g., :obj:`numpy.integer`, :obj:`numpy.floating`). |
| 102 | +
|
| 103 | + This function is intended for use as an attrs-style validator. |
| 104 | +
|
| 105 | + Parameters |
| 106 | + ---------- |
| 107 | + inst : :obj:`~nifreeze.data.pet.PET` |
| 108 | + The instance being validated (unused; present for validator signature). |
| 109 | + attr : :obj:`~attrs.Attribute` |
| 110 | + The attribute being validated; attr.name is used in the error message. |
| 111 | + value : :obj:`Any` |
| 112 | + The value to validate. |
| 113 | +
|
| 114 | + Raises |
| 115 | + ------ |
| 116 | + exc:`ValueError` |
| 117 | + If ``value`` is not an int/float or a NumPy numeric scalar type. |
| 118 | + """ |
| 119 | + if not isinstance(value, (int, float, np.integer, np.floating)): |
| 120 | + raise ValueError(SCALAR_ATTRIBUTE_OBJECT_ERROR_MSG.format(attribute=attr.name)) |
| 121 | + |
44 | 122 |
|
45 | 123 | @attrs.define(slots=True) |
46 | 124 | class PET(BaseDataset[np.ndarray]): |
47 | 125 | """Data representation structure for PET data.""" |
48 | 126 |
|
49 | | - midframe: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) |
| 127 | + midframe: np.ndarray = attrs.field( |
| 128 | + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array |
| 129 | + ) |
50 | 130 | """A (N,) numpy array specifying the midpoint timing of each sample or frame.""" |
51 | | - total_duration: float = attrs.field(default=None, repr=True) |
| 131 | + total_duration: float = attrs.field(default=None, repr=True, validator=validate_scalar) |
52 | 132 | """A float representing the total duration of the dataset.""" |
53 | | - uptake: np.ndarray = attrs.field(default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp)) |
| 133 | + uptake: np.ndarray = attrs.field( |
| 134 | + default=None, repr=_data_repr, eq=attrs.cmp_using(eq=_cmp), validator=validate_1d_array |
| 135 | + ) |
54 | 136 | """A (N,) numpy array specifying the uptake value of each sample or frame.""" |
55 | 137 |
|
| 138 | + def __attrs_post_init__(self) -> None: |
| 139 | + """Enforce presence and basic consistency of required PET fields at |
| 140 | + instantiation time. |
| 141 | +
|
| 142 | + Specifically, the length of the midframe and uptake attributes must |
| 143 | + match the last dimension of the data (number of frames). |
| 144 | + """ |
| 145 | + data_frames = int(self.dataobj.shape[-1]) |
| 146 | + |
| 147 | + if len(self.midframe) != data_frames: |
| 148 | + raise ValueError( |
| 149 | + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format( |
| 150 | + attribute=attrs.fields_dict(self.__class__)["midframe"].name, |
| 151 | + attr_len=len(self.midframe), |
| 152 | + data_frames=data_frames, |
| 153 | + ) |
| 154 | + ) |
| 155 | + |
| 156 | + if len(self.uptake) != data_frames: |
| 157 | + raise ValueError( |
| 158 | + ATTRIBUTE_SHAPE_MISMATCH_ERROR_MSG.format( |
| 159 | + attribute=attrs.fields_dict(self.__class__)["uptake"].name, |
| 160 | + attr_len=len(self.uptake), |
| 161 | + data_frames=data_frames, |
| 162 | + ) |
| 163 | + ) |
| 164 | + |
56 | 165 | def _getextra(self, idx: int | slice | tuple | np.ndarray) -> tuple[np.ndarray]: |
57 | 166 | return (self.midframe[idx],) |
58 | 167 |
|
@@ -258,37 +367,46 @@ def from_nii( |
258 | 367 | raise NotImplementedError |
259 | 368 |
|
260 | 369 | filename = Path(filename) |
261 | | - # Load from NIfTI |
| 370 | + |
| 371 | + # 1) Load a NIfTI |
262 | 372 | img = load_api(filename, SpatialImage) |
263 | | - data = img.get_fdata(dtype=np.float32) |
264 | | - pet_obj = PET( |
265 | | - dataobj=data, |
266 | | - affine=img.affine, |
267 | | - ) |
| 373 | + fulldata = img.get_fdata(dtype=np.float32) |
| 374 | + |
| 375 | + # 2) Determine uptake value |
| 376 | + uptake = _compute_uptake_statistic(fulldata, stat_func=np.sum) |
268 | 377 |
|
269 | | - pet_obj.uptake = _compute_uptake_statistic(data, stat_func=np.sum) |
| 378 | + # 3) Compute temporal features |
270 | 379 |
|
271 | 380 | # Convert to a float32 numpy array and zero out the earliest time |
272 | 381 | frame_time_arr = np.array(frame_time, dtype=np.float32) |
273 | 382 | frame_time_arr -= frame_time_arr[0] |
274 | | - pet_obj.midframe = frame_time_arr |
| 383 | + midframe = frame_time_arr |
275 | 384 |
|
276 | 385 | # If the user doesn't provide frame_duration, we derive it: |
277 | 386 | if frame_duration is None: |
278 | | - durations = _compute_frame_duration(pet_obj.midframe) |
| 387 | + durations = _compute_frame_duration(midframe) |
279 | 388 | else: |
280 | 389 | durations = np.array(frame_duration, dtype=np.float32) |
281 | 390 |
|
282 | | - # Set total_duration and shift frame_time to the midpoint |
283 | | - pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1]) |
284 | | - pet_obj.midframe = frame_time_arr + 0.5 * durations |
| 391 | + # Compute total_duration and shift midframe to the midpoint |
| 392 | + total_duration = float(frame_time_arr[-1] + durations[-1]) |
| 393 | + midframe = frame_time_arr + 0.5 * durations |
285 | 394 |
|
286 | | - # If a brain mask is provided, load and attach |
| 395 | + # 4) If a brainmask_file was provided, load it |
| 396 | + brainmask_data = None |
287 | 397 | if brainmask_file is not None: |
288 | 398 | mask_img = load_api(brainmask_file, SpatialImage) |
289 | | - pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool) |
| 399 | + brainmask_data = np.asanyarray(mask_img.dataobj, dtype=bool) |
290 | 400 |
|
291 | | - return pet_obj |
| 401 | + # 5) Create and return the DWI instance. |
| 402 | + return PET( |
| 403 | + dataobj=fulldata, |
| 404 | + affine=img.affine, |
| 405 | + brainmask=brainmask_data, |
| 406 | + midframe=midframe, |
| 407 | + total_duration=total_duration, |
| 408 | + uptake=uptake, |
| 409 | + ) |
292 | 410 |
|
293 | 411 |
|
294 | 412 | def _compute_frame_duration(midframe: np.ndarray) -> np.ndarray: |
|
0 commit comments