Skip to content

Commit 4ea4636

Browse files
committed
enh: adapt PET data representation class
1 parent f305783 commit 4ea4636

File tree

3 files changed

+217
-120
lines changed

3 files changed

+217
-120
lines changed

src/nifreeze/data/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,9 @@ class BaseDataset:
8888

8989
def __len__(self) -> int:
9090
"""Obtain the number of volumes/frames in the dataset."""
91+
if self.dataobj is None:
92+
return 0
93+
9194
return self.dataobj.shape[-1]
9295

9396
def __getitem__(

src/nifreeze/data/dmri.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,33 @@ class DWI(BaseDataset):
5353
eddy_xfms = attr.ib(default=None)
5454
"""List of transforms to correct for estimatted eddy current distortions."""
5555

56+
def __getitem__(
57+
self, idx: int | slice | tuple | np.ndarray
58+
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
59+
"""
60+
Returns volume(s) and corresponding affine(s) and gradient(s) through fancy indexing.
61+
62+
Parameters
63+
----------
64+
idx : :obj:`int` or :obj:`slice` or :obj:`tuple` or :obj:`~numpy.ndarray`
65+
Indexer for the last dimension (or possibly other dimensions if extended).
66+
67+
Returns
68+
-------
69+
volumes : np.ndarray
70+
The selected data subset. If `idx` is a single integer, this will have shape
71+
``(X, Y, Z)``, otherwise it may have shape ``(X, Y, Z, k)``.
72+
motion_affine : np.ndarray or None
73+
The corresponding per-volume motion affine(s) or `None` if identity transform(s).
74+
gradient : np.ndarray
75+
The corresponding gradient(s), which may have shape ``(4,)`` if a single volume
76+
or ``(k, 4)`` if multiple volumes, or None if gradients are not available.
77+
78+
"""
79+
80+
data, affine = super().__getitem__(idx)
81+
return data, affine, self.gradients[idx, ...]
82+
5683
def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
5784
"""
5885
Set an affine transform for a particular DWI volume, resample that volume,
@@ -149,6 +176,29 @@ def from_filename(cls, filename: Path | str) -> DWI:
149176

150177
return cls(**data)
151178

179+
def to_filename(
180+
self,
181+
filename: Path | str,
182+
compression: str | None = None,
183+
compression_opts: Any = None,
184+
) -> None:
185+
"""
186+
Write the dMRI dataset to an HDF5 file on disk.
187+
188+
Parameters
189+
----------
190+
filename : Path or str
191+
Path to the output HDF5 file.
192+
compression : str, optional
193+
Compression filter, e.g. 'gzip'. Default is None (no compression).
194+
compression_opts : Any, optional
195+
Compression level or other parameters for the HDF5 dataset.
196+
"""
197+
super().to_filename(filename, compression=compression, compression_opts=compression_opts)
198+
# Overriding if you'd like to set a custom attribute, for example:
199+
with h5py.File(filename, "r+") as out_file:
200+
out_file.attrs["Type"] = "dmri"
201+
152202
def to_nifti(self, filename: Path | str) -> None:
153203
"""
154204
Write a NIfTI 1.0 file to disk, and also write out the gradient table

src/nifreeze/data/pet.py

Lines changed: 164 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -22,157 +22,201 @@
2222
#
2323
"""PET data representation."""
2424

25-
from collections import namedtuple
25+
from __future__ import annotations
26+
2627
from pathlib import Path
27-
from tempfile import mkdtemp
28+
from typing import Any, Union
2829

2930
import attr
3031
import h5py
3132
import nibabel as nb
3233
import numpy as np
33-
from nitransforms.linear import Affine
34-
3534

36-
def _data_repr(value):
37-
if value is None:
38-
return "None"
39-
return f"<{'x'.join(str(v) for v in value.shape)} ({value.dtype})>"
35+
from nifreeze.data.base import BaseDataset, _cmp, _data_repr
4036

4137

4238
@attr.s(slots=True)
43-
class PET:
44-
"""Data representation structure for PET data."""
45-
46-
dataobj = attr.ib(default=None, repr=_data_repr)
47-
"""A numpy ndarray object for the data array, without *b=0* volumes."""
48-
affine = attr.ib(default=None, repr=_data_repr)
49-
"""Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
50-
brainmask = attr.ib(default=None, repr=_data_repr)
51-
"""A boolean ndarray object containing a corresponding brainmask."""
52-
frame_time = attr.ib(default=None, repr=_data_repr)
53-
"""A 1D numpy array with the midpoint timing of each sample."""
54-
total_duration = attr.ib(default=None, repr=_data_repr)
55-
"""A float number representing the total duration of acquisition."""
56-
57-
em_affines = attr.ib(default=None)
58-
"""
59-
List of :obj:`nitransforms.linear.Affine` objects that bring
60-
PET timepoints into alignment.
39+
class PET(BaseDataset):
6140
"""
62-
_filepath = attr.ib(
63-
factory=lambda: Path(mkdtemp()) / "em_cache.h5",
64-
repr=False,
65-
)
66-
"""A path to an HDF5 file to store the whole dataset."""
41+
Data representation structure for PET data, inheriting from BaseDataset.
6742
68-
def __len__(self):
69-
"""Obtain the number of high-*b* orientations."""
70-
return self.dataobj.shape[-1]
43+
In addition to the base attributes (e.g., dataobj, affine), this PET class stores:
44+
- frame_time: a 1D array specifying the midpoint timing of each frame.
45+
- total_duration: a float specifying the total acquisition duration.
7146
72-
def set_transform(self, index, affine, order=3):
73-
"""Set an affine, and update data object and gradients."""
74-
reference = namedtuple("ImageGrid", ("shape", "affine"))(
75-
shape=self.dataobj.shape[:3], affine=self.affine
76-
)
77-
xform = Affine(matrix=affine, reference=reference)
47+
"""
7848

79-
if not Path(self._filepath).exists():
80-
self.to_filename(self._filepath)
49+
frame_time: np.ndarray | None = attr.ib(
50+
default=None, repr=_data_repr, eq=attr.cmp_using(eq=_cmp)
51+
)
52+
"""
53+
A 1D numpy array specifying the midpoint timing of each sample or frame.
54+
Typically shape (N,).
55+
"""
56+
total_duration: float | None = attr.ib(default=None, repr=True)
57+
"""
58+
A float representing the total duration of the entire PET acquisition.
59+
"""
8160

82-
# read original PET
83-
with h5py.File(self._filepath, "r") as in_file:
84-
root = in_file["/0"]
85-
dframe = np.asanyarray(root["dataobj"][..., index])
61+
def __getitem__(
62+
self, idx: int | slice | tuple | np.ndarray
63+
) -> tuple[np.ndarray, np.ndarray | None, np.ndarray | None]:
64+
"""
65+
Returns volume(s) and corresponding affine(s) and timing(s) through fancy indexing.
8666
87-
dmoving = nb.Nifti1Image(dframe, self.affine, None)
67+
Parameters
68+
----------
69+
idx : :obj:`int` or :obj:`slice` or :obj:`tuple` or :obj:`~numpy.ndarray`
70+
Indexer for the last dimension (or possibly other dimensions if extended).
8871
89-
# resample and update orientation at index
90-
self.dataobj[..., index] = np.asanyarray(
91-
xform.apply(dmoving, order=order).dataobj,
92-
dtype=self.dataobj.dtype,
93-
)
72+
Returns
73+
-------
74+
volumes : np.ndarray
75+
The selected data subset. If `idx` is a single integer, this will have shape
76+
``(X, Y, Z)``, otherwise it may have shape ``(X, Y, Z, k)``.
77+
motion_affine : np.ndarray or None
78+
The corresponding per-volume motion affine(s) or `None` if identity transform(s).
79+
time : float
80+
The corresponding frame time.
9481
95-
# update transform
96-
if self.em_affines is None:
97-
self.em_affines = [None] * len(self)
82+
"""
9883

99-
self.em_affines[index] = xform
84+
data, affine = super().__getitem__(idx)
85+
return data, affine, self.frame_time[idx]
10086

101-
def to_filename(self, filename, compression=None, compression_opts=None):
102-
"""Write an HDF5 file to disk."""
103-
filename = Path(filename)
104-
if not filename.name.endswith(".h5"):
105-
filename = filename.parent / f"{filename.name}.h5"
106-
107-
with h5py.File(filename, "w") as out_file:
108-
out_file.attrs["Format"] = "EMC/PET"
109-
out_file.attrs["Version"] = np.uint16(1)
110-
root = out_file.create_group("/0")
111-
root.attrs["Type"] = "pet"
112-
for f in attr.fields(self.__class__):
113-
if f.name.startswith("_"):
114-
continue
87+
@classmethod
88+
def from_filename(cls, filename: Union[str, Path]) -> PET:
89+
"""
90+
Read an HDF5 file from disk and create a PET object.
91+
92+
Parameters
93+
----------
94+
filename : str or Path
95+
The HDF5 file path to read.
96+
97+
Returns
98+
-------
99+
PET
100+
A PET dataset with data loaded from the specified file.
101+
"""
102+
import attr
115103

116-
value = getattr(self, f.name)
117-
if value is not None:
118-
root.create_dataset(
119-
f.name,
120-
data=value,
121-
compression=compression,
122-
compression_opts=compression_opts,
123-
)
124-
125-
def to_nifti(self, filename, *_):
126-
"""Write a NIfTI 1.0 file to disk."""
127-
nii = nb.Nifti1Image(self.dataobj, self.affine, None)
128-
nii.header.set_xyzt_units("mm")
129-
nii.to_filename(filename)
104+
filename = Path(filename)
105+
data: dict[str, Any] = {}
130106

131-
@classmethod
132-
def from_filename(cls, filename):
133-
"""Read an HDF5 file from disk."""
134107
with h5py.File(filename, "r") as in_file:
135108
root = in_file["/0"]
136-
data = {k: np.asanyarray(v) for k, v in root.items() if not k.startswith("_")}
109+
for f in attr.fields(cls):
110+
# skip private attributes (start with '_')
111+
if f.name.startswith("_"):
112+
continue
113+
if f.name in root:
114+
data[f.name] = np.asanyarray(root[f.name])
115+
else:
116+
data[f.name] = None
117+
137118
return cls(**data)
138119

120+
def to_filename(
121+
self,
122+
filename: Path | str,
123+
compression: str | None = None,
124+
compression_opts: Any = None,
125+
) -> None:
126+
"""
127+
Write the PET dataset to an HDF5 file on disk.
128+
129+
Parameters
130+
----------
131+
filename : Path or str
132+
Path to the output HDF5 file.
133+
compression : str, optional
134+
Compression filter, e.g. 'gzip'. Default is None (no compression).
135+
compression_opts : Any, optional
136+
Compression level or other parameters for the HDF5 dataset.
137+
"""
138+
super().to_filename(filename, compression=compression, compression_opts=compression_opts)
139+
# Overriding if you'd like to set a custom attribute, for example:
140+
with h5py.File(filename, "r+") as out_file:
141+
out_file.attrs["Type"] = "pet"
142+
139143

140144
def load(
141-
filename,
142-
brainmask_file=None,
143-
frame_time=None,
144-
frame_duration=None,
145-
):
146-
"""Load PET data."""
145+
filename: Path | str,
146+
brainmask_file: Path | str | None = None,
147+
frame_time: np.ndarray | list[float] | None = None,
148+
frame_duration: np.ndarray | list[float] | None = None,
149+
) -> PET:
150+
"""
151+
Load PET data from HDF5 or NIfTI, creating a PET object with appropriate metadata.
152+
153+
Parameters
154+
----------
155+
filename : Path or str
156+
Path to the PET data (HDF5 or NIfTI).
157+
brainmask_file : Path or str, optional
158+
An optional brain mask NIfTI file.
159+
frame_time : np.ndarray or list of float, optional
160+
The start times of each frame relative to the beginning of the acquisition.
161+
If None, an error is raised (since BIDS requires FrameTimesStart).
162+
frame_duration : np.ndarray or list of float, optional
163+
The duration of each frame. If None, it is derived by the difference
164+
of consecutive frame_times, defaulting the last frame to match the second-last.
165+
166+
Returns
167+
-------
168+
PET
169+
A PET object storing the data, metadata, and any optional mask.
170+
171+
Raises
172+
------
173+
RuntimeError
174+
If `frame_time` is not provided (BIDS requires it).
175+
"""
147176
filename = Path(filename)
148-
if filename.name.endswith(".h5"):
149-
return PET.from_filename(filename)
150-
151-
img = nb.load(filename)
152-
retval = PET(
153-
dataobj=img.get_fdata(dtype="float32"),
154-
affine=img.affine,
155-
)
177+
if filename.suffix == ".h5":
178+
# Load from HDF5
179+
pet_obj = PET.from_filename(filename)
180+
else:
181+
# Load from NIfTI
182+
img = nb.load(str(filename))
183+
data = img.get_fdata(dtype=np.float32)
184+
pet_obj = PET(
185+
dataobj=data,
186+
affine=img.affine,
187+
)
156188

157-
if frame_time is None:
189+
# Verify the user provided frame_time if not already in the PET object
190+
if pet_obj.frame_time is None and frame_time is None:
158191
raise RuntimeError(
159-
"Start time of frames is mandatory (see https://bids-specification.readthedocs.io/"
160-
"en/stable/glossary.html#objects.metadata.FrameTimesStart)"
192+
"The `frame_time` is mandatory for PET data to comply with BIDS. "
193+
"See https://bids-specification.readthedocs.io for details."
161194
)
162195

163-
frame_time = np.array(frame_time, dtype="float32") - frame_time[0]
164-
if frame_duration is None:
165-
frame_duration = np.diff(frame_time)
166-
if len(frame_duration) == (retval.dataobj.shape[-1] - 1):
167-
frame_duration = np.append(frame_duration, frame_duration[-1])
168-
169-
retval.total_duration = frame_time[-1] + frame_duration[-1]
170-
retval.frame_time = frame_time + 0.5 * np.array(frame_duration, dtype="float32")
171-
172-
assert len(retval.frame_time) == retval.dataobj.shape[-1]
196+
# If the user supplied new values, set them
197+
if frame_time is not None:
198+
# Convert to a float32 numpy array and zero out the earliest time
199+
frame_time_arr = np.array(frame_time, dtype=np.float32)
200+
frame_time_arr -= frame_time_arr[0]
201+
pet_obj.frame_time = frame_time_arr
173202

174-
if brainmask_file:
175-
mask = nb.load(brainmask_file)
176-
retval.brainmask = np.asanyarray(mask.dataobj)
177-
178-
return retval
203+
# If the user doesn't provide frame_duration, we derive it:
204+
if frame_duration is None:
205+
frame_time_arr = pet_obj.frame_time
206+
# If shape is e.g. (N,), then we can do
207+
durations = np.diff(frame_time_arr)
208+
if len(durations) == (len(frame_time_arr) - 1):
209+
durations = np.append(durations, durations[-1]) # last frame same as second-last
210+
else:
211+
durations = np.array(frame_duration, dtype=np.float32)
212+
213+
# Set total_duration and shift frame_time to the midpoint
214+
pet_obj.total_duration = float(frame_time_arr[-1] + durations[-1])
215+
pet_obj.frame_time = frame_time_arr + 0.5 * durations
216+
217+
# If a brain mask is provided, load and attach
218+
if brainmask_file is not None:
219+
mask_img = nb.load(str(brainmask_file))
220+
pet_obj.brainmask = np.asanyarray(mask_img.dataobj, dtype=bool)
221+
222+
return pet_obj

0 commit comments

Comments
 (0)