2727from collections import namedtuple
2828from pathlib import Path
2929from tempfile import mkdtemp
30- from typing import Any
30+ from typing import Any , Generic , TypeVarTuple
3131
3232import attr
3333import h5py
3434import nibabel as nb
3535import numpy as np
36+ from nibabel .spatialimages import SpatialHeader , SpatialImage
3637from nitransforms .linear import Affine
3738
39+ from nifreeze .utils .ndimage import load_api
40+
3841NFDH5_EXT = ".h5"
3942
4043
44+ Ts = TypeVarTuple ("Ts" )
45+
46+
4147def _data_repr (value : np .ndarray | None ) -> str :
4248 if value is None :
4349 return "None"
@@ -52,7 +58,7 @@ def _cmp(lh: Any, rh: Any) -> bool:
5258
5359
5460@attr .s (slots = True )
55- class BaseDataset :
61+ class BaseDataset ( Generic [ * Ts ]) :
5662 """
5763 Base dataset representation structure.
5864
@@ -68,15 +74,15 @@ class BaseDataset:
6874
6975 """
7076
71- dataobj = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
77+ dataobj : np . ndarray = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
7278 """A :obj:`~numpy.ndarray` object for the data array."""
73- affine = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
79+ affine : np . ndarray = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
7480 """Best affine for RAS-to-voxel conversion of coordinates (NIfTI header)."""
75- brainmask = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
81+ brainmask : np . ndarray = attr .ib (default = None , repr = _data_repr , eq = attr .cmp_using (eq = _cmp ))
7682 """A boolean ndarray object containing a corresponding brainmask."""
77- motion_affines = attr .ib (default = None , eq = attr .cmp_using (eq = _cmp ))
83+ motion_affines : np . ndarray = attr .ib (default = None , eq = attr .cmp_using (eq = _cmp ))
7884 """List of :obj:`~nitransforms.linear.Affine` realigning the dataset."""
79- datahdr = attr .ib (default = None )
85+ datahdr : SpatialHeader = attr .ib (default = None )
8086 """A :obj:`~nibabel.spatialimages.SpatialHeader` header corresponding to the data."""
8187
8288 _filepath = attr .ib (
@@ -93,9 +99,13 @@ def __len__(self) -> int:
9399
94100 return self .dataobj .shape [- 1 ]
95101
102+ def _getextra (self , idx : int | slice | tuple | np .ndarray ) -> tuple [* Ts ]:
103+ # PY312: Default values for TypeVarTuples are not yet supported
104+ return () # type: ignore[return-value]
105+
96106 def __getitem__ (
97107 self , idx : int | slice | tuple | np .ndarray
98- ) -> tuple [np .ndarray , np .ndarray | None ]:
108+ ) -> tuple [np .ndarray , np .ndarray | None , * Ts ]:
99109 """
100110 Returns volume(s) and corresponding affine(s) through fancy indexing.
101111
@@ -118,7 +128,7 @@ def __getitem__(
118128 raise ValueError ("No data available (dataobj is None)." )
119129
120130 affine = self .motion_affines [idx ] if self .motion_affines is not None else None
121- return self .dataobj [..., idx ], affine
131+ return self .dataobj [..., idx ], affine , * self . _getextra ( idx )
122132
123133 @classmethod
124134 def from_filename (cls , filename : Path | str ) -> BaseDataset :
@@ -159,9 +169,8 @@ def set_transform(self, index: int, affine: np.ndarray, order: int = 3) -> None:
159169 The order of the spline interpolation.
160170
161171 """
162- reference = namedtuple ("ImageGrid" , ("shape" , "affine" ))(
163- shape = self .dataobj .shape [:3 ], affine = self .affine
164- )
172+ ImageGrid = namedtuple ("ImageGrid" , ("shape" , "affine" ))
173+ reference = ImageGrid (shape = self .dataobj .shape [:3 ], affine = self .affine )
165174
166175 xform = Affine (matrix = affine , reference = reference )
167176
@@ -227,7 +236,7 @@ def to_filename(
227236 compression_opts = compression_opts ,
228237 )
229238
230- def to_nifti (self , filename : Path ) -> None :
239+ def to_nifti (self , filename : Path | str ) -> None :
231240 """
232241 Write a NIfTI file to disk.
233242
@@ -247,7 +256,7 @@ def load(
247256 filename : Path | str ,
248257 brainmask_file : Path | str | None = None ,
249258 motion_file : Path | str | None = None ,
250- ) -> BaseDataset :
259+ ) -> BaseDataset [()] :
251260 """
252261 Load 4D data from a filename or an HDF5 file.
253262
@@ -279,11 +288,11 @@ def load(
279288 if filename .name .endswith (NFDH5_EXT ):
280289 return BaseDataset .from_filename (filename )
281290
282- img = nb . load (filename )
283- retval = BaseDataset (dataobj = img .dataobj , affine = img .affine )
291+ img = load_api (filename , SpatialImage )
292+ retval : BaseDataset [()] = BaseDataset (dataobj = np . asanyarray ( img .dataobj ) , affine = img .affine )
284293
285294 if brainmask_file :
286- mask = nb . load (brainmask_file )
295+ mask = load_api (brainmask_file , SpatialImage )
287296 retval .brainmask = np .asanyarray (mask .dataobj )
288297
289298 return retval
0 commit comments