2727import json
2828from collections import namedtuple
2929from pathlib import Path
30+ from typing import Any , Tuple
3031
3132import attrs
3233import h5py
3334import nibabel as nb
3435import numpy as np
3536from nibabel .spatialimages import SpatialImage
3637from nitransforms .linear import Affine
38+ from typing_extensions import Self
3739
3840from nifreeze .data .base import BaseDataset , _cmp , _data_repr
3941from nifreeze .utils .ndimage import load_api
@@ -79,7 +81,9 @@ def __getitem__(
7981 """
8082 return super ().__getitem__ (idx )
8183
82- def lofo_split (self , index ):
84+ def lofo_split (
85+ self , index : int
86+ ) -> Tuple [Tuple [np .ndarray , np .ndarray | None ], Tuple [np .ndarray , np .ndarray | None ]]:
8387 """
8488 Leave-one-frame-out (LOFO) for PET data.
8589
@@ -118,11 +122,10 @@ def lofo_split(self, index):
118122
119123 return (train_data , train_timings ), (test_data , test_timing )
120124
121- def set_transform (self , index , affine , order = 3 ) :
125+ def set_transform (self , index : int , affine : np . ndarray , order : int = 3 ) -> None :
122126 """Set an affine, and update data object and gradients."""
123- reference = namedtuple ("ImageGrid" , ("shape" , "affine" ))(
124- shape = self .dataobj .shape [:3 ], affine = self .affine
125- )
127+ ImageGrid = namedtuple ("ImageGrid" , ("shape" , "affine" ))
128+ reference = ImageGrid (self .dataobj .shape [:3 ], self .affine )
126129 xform = Affine (matrix = affine , reference = reference )
127130
128131 if not Path (self ._filepath ).exists ():
@@ -147,7 +150,9 @@ def set_transform(self, index, affine, order=3):
147150
148151 self .motion_affines [index ] = xform
149152
150- def to_filename (self , filename , compression = None , compression_opts = None ):
153+ def to_filename (
154+ self , filename : Path | str , compression : str | None = None , compression_opts : Any = None
155+ ) -> None :
151156 """Write an HDF5 file to disk."""
152157 filename = Path (filename )
153158 if not filename .name .endswith (".h5" ):
@@ -178,21 +183,23 @@ def to_nifti(self, filename, *_):
178183 nii .to_filename (filename )
179184
180185 @classmethod
181- def from_filename (cls , filename ) :
186+ def from_filename (cls , filename : Path | str ) -> Self :
182187 """Read an HDF5 file from disk."""
183188 with h5py .File (filename , "r" ) as in_file :
184189 root = in_file ["/0" ]
185190 data = {k : np .asanyarray (v ) for k , v in root .items () if not k .startswith ("_" )}
186191 return cls (** data )
187192
188193 @classmethod
189- def load (cls , filename , json_file , brainmask_file = None ):
194+ def load (
195+ cls , filename : Path | str , json_file : Path | str , brainmask_file : Path | str | None = None
196+ ) -> Self :
190197 """Load PET data."""
191198 filename = Path (filename )
192199 if filename .name .endswith (".h5" ):
193200 return cls .from_filename (filename )
194201
195- img = nb . load (filename )
202+ img = load_api (filename , SpatialImage )
196203 retval = cls (
197204 dataobj = img .get_fdata (dtype = "float32" ),
198205 affine = img .affine ,
@@ -212,7 +219,7 @@ def load(cls, filename, json_file, brainmask_file=None):
212219 assert len (retval .midframe ) == retval .dataobj .shape [- 1 ]
213220
214221 if brainmask_file :
215- mask = nb . load (brainmask_file )
222+ mask = load_api (brainmask_file , SpatialImage )
216223 retval .brainmask = np .asanyarray (mask .dataobj )
217224
218225 return retval
0 commit comments