2727import json
2828from collections import namedtuple
2929from pathlib import Path
30- from typing import Callable
30+ from typing import Any , Callable
3131
3232import attrs
3333import h5py
3434import nibabel as nb
3535import numpy as np
3636from nibabel .spatialimages import SpatialImage
3737from nitransforms .linear import Affine
38+ from typing_extensions import Self
3839
3940from nifreeze .data .base import BaseDataset , _cmp , _data_repr
4041from nifreeze .utils .ndimage import load_api
@@ -123,11 +124,10 @@ def lofo_split(self, index):
123124
124125 return (train_data , train_timings ), (test_data , test_timing )
125126
126- def set_transform (self , index , affine , order = 3 ) :
127+ def set_transform (self , index : int , affine : np . ndarray , order : int = 3 ) -> None :
127128 """Set an affine, and update data object and gradients."""
128- reference = namedtuple ("ImageGrid" , ("shape" , "affine" ))(
129- shape = self .dataobj .shape [:3 ], affine = self .affine
130- )
129+ ImageGrid = namedtuple ("ImageGrid" , ("shape" , "affine" ))
130+ reference = ImageGrid (self .dataobj .shape [:3 ], self .affine )
131131 xform = Affine (matrix = affine , reference = reference )
132132
133133 if not Path (self ._filepath ).exists ():
@@ -152,7 +152,9 @@ def set_transform(self, index, affine, order=3):
152152
153153 self .motion_affines [index ] = xform
154154
155- def to_filename (self , filename , compression = None , compression_opts = None ):
155+ def to_filename (
156+ self , filename : Path | str , compression : str | None = None , compression_opts : Any = None
157+ ) -> None :
156158 """Write an HDF5 file to disk."""
157159 filename = Path (filename )
158160 if not filename .name .endswith (".h5" ):
@@ -183,21 +185,23 @@ def to_nifti(self, filename, *_):
183185 nii .to_filename (filename )
184186
185187 @classmethod
186- def from_filename (cls , filename ) :
188+ def from_filename (cls , filename : Path | str ) -> Self :
187189 """Read an HDF5 file from disk."""
188190 with h5py .File (filename , "r" ) as in_file :
189191 root = in_file ["/0" ]
190192 data = {k : np .asanyarray (v ) for k , v in root .items () if not k .startswith ("_" )}
191193 return cls (** data )
192194
193195 @classmethod
194- def load (cls , filename , json_file , brainmask_file = None ):
196+ def load (
197+ cls , filename : Path | str , json_file : Path | str , brainmask_file : Path | str | None = None
198+ ) -> Self :
195199 """Load PET data."""
196200 filename = Path (filename )
197201 if filename .name .endswith (".h5" ):
198202 return cls .from_filename (filename )
199203
200- img = nb . load (filename )
204+ img = load_api (filename , SpatialImage )
201205 retval = cls (
202206 dataobj = img .get_fdata (dtype = "float32" ),
203207 affine = img .affine ,
@@ -217,7 +221,7 @@ def load(cls, filename, json_file, brainmask_file=None):
217221 assert len (retval .midframe ) == retval .dataobj .shape [- 1 ]
218222
219223 if brainmask_file :
220- mask = nb . load (brainmask_file )
224+ mask = load_api (brainmask_file , SpatialImage )
221225 retval .brainmask = np .asanyarray (mask .dataobj )
222226
223227 return retval
0 commit comments