1515from .._meta import LogMixin
1616
1717if TYPE_CHECKING :
18+ from _typeshed import AnyPath
1819 from ..core import Phantom , DynamicData
1920 from ..core import SimConfig
2021
2324log = logging .getLogger (__name__ )
2425
2526
26- def read_mrd_header (filename : os . PathLike | mrd .Dataset ) -> mrd .xsd .ismrmrdHeader :
27+ def read_mrd_header (filename : AnyPath | mrd .Dataset ) -> mrd .xsd .ismrmrdHeader :
2728 """Read the header of the MRD file."""
2829 if isinstance (filename , mrd .Dataset ):
2930 dataset = filename
@@ -50,7 +51,7 @@ class MRDLoader(LogMixin):
5051
5152 def __init__ (
5253 self ,
53- filename : os . PathLike ,
54+ filename : AnyPath ,
5455 dataset_name : str = "dataset" ,
5556 writeable : bool = False ,
5657 swmr : bool = False ,
@@ -85,7 +86,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
8586
8687 def iter_frames (
8788 self , start : int | None = None , stop : int | None = None , step : int | None = None
88- ) -> Generator [tuple [int , NDArray , NDArray ], None , None ]:
89+ ) -> Generator [tuple [int , NDArray [ np . float32 ] , NDArray [ np . complex64 ] ], None , None ]:
8990 """Iterate over kspace frames of the dataset."""
9091 if start is None :
9192 start = 0
@@ -97,7 +98,7 @@ def iter_frames(
9798 for i in np .arange (start , stop , step ):
9899 yield i , * self .get_kspace_frame (i )
99100
100- def get_kspace_frame (self , idx : int ) -> tuple [NDArray , NDArray ]:
101+ def get_kspace_frame (self , idx : int ) -> tuple [NDArray [ np . float32 ] , NDArray [ np . complex64 ] ]:
101102 """Get k-space frame trajectory/mask and data."""
102103 raise NotImplementedError ()
103104
@@ -275,15 +276,15 @@ def get_sim_conf(self) -> SimConfig:
275276 """Parse the sim config."""
276277 return parse_sim_conf (self .header )
277278
278- def _get_image_data (self , name : str , idx : int = 0 ) -> NDArray | None :
279+ def _get_image_data (self , name : str , idx : int = 0 ) -> NDArray [ np . complex64 ] | None :
279280 try :
280- image = self ._read_image (name , idx ).data
281+ image = self ._read_image (name , idx ).data . astype ( np . complex64 )
281282 except LookupError :
282283 log .warning (f"No { name } found in the dataset." )
283284 return None
284285 return image
285286
286- def get_smaps (self ) -> NDArray | None :
287+ def get_smaps (self ) -> NDArray [ np . complex64 ] | None :
287288 """Load the sensitivity maps from the dataset."""
288289 return self ._get_image_data ("smaps" )
289290
@@ -344,7 +345,7 @@ class NonCartesianFrameDataLoader(MRDLoader):
344345
345346 def get_kspace_frame (
346347 self , idx : int , shot_dim : bool = False
347- ) -> tuple [np .ndarray , np .ndarray ]:
348+ ) -> tuple [NDArray [ np .float32 ], NDArray [ np .complex64 ] ]:
348349 """Get the k-space frame and the associated trajectory.
349350
350351 Parameters
0 commit comments