Skip to content

Commit 3bae00b

Browse files
committed
feat: organize import and typing
1 parent 8771b2e commit 3bae00b

File tree

10 files changed

+61
-34
lines changed

10 files changed

+61
-34
lines changed

src/snake/core/engine/base.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections.abc import Mapping, Sequence
1010
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
1111
from multiprocessing.managers import SharedMemoryManager
12+
from pathlib import Path
1213
from tempfile import TemporaryDirectory
1314
from typing import Any, ClassVar
1415

@@ -17,21 +18,25 @@
1718
from numpy.typing import NDArray
1819
from tqdm.auto import tqdm
1920

21+
from typing_extensions import dataclass_transform
22+
2023
from snake._meta import EnvConfig, MetaDCRegister, batched
2124

2225
from ...mrd_utils import MRDLoader, make_base_mrd
26+
from ..handlers import AbstractHandler, HandlerList
2327
from ..parallel import ArrayProps
2428
from ..phantom import DynamicData, Phantom, PropTissueEnum
25-
from ..simulation import SimConfig
2629
from ..sampling import BaseSampler
27-
from ..handlers import AbstractHandler, HandlerList
30+
from ..simulation import SimConfig
2831
from .utils import get_ideal_phantom, get_noise
2932

33+
AnyPath = str | Path
3034

35+
@dataclass_transform(kw_only_default=True)
3136
class MetaEngine(MetaDCRegister):
3237
"""MetaClass for engines."""
3338

34-
dunder_name = "engine"
39+
dunder_name: ClassVar[str] = "engine"
3540

3641

3742
class BaseAcquisitionEngine(metaclass=MetaEngine):
@@ -113,7 +118,7 @@ def _write_chunk_data(
113118

114119
def _acquire_ksp_job(
115120
self,
116-
filename: os.PathLike,
121+
filename: AnyPath,
117122
chunk: Sequence[int],
118123
tmp_dir: str,
119124
shared_phantom_props: (
@@ -160,7 +165,7 @@ def _acquire_ksp_job(
160165

161166
def __call__(
162167
self,
163-
filename: os.PathLike,
168+
filename: AnyPath,
164169
sampler: BaseSampler,
165170
phantom: Phantom,
166171
sim_conf: SimConfig,
@@ -175,7 +180,7 @@ def __call__(
175180
176181
Parameters
177182
----------
178-
filename : os.PathLike
183+
filename : AnyPath
179184
The path to the MRD file.
180185
sampler : BaseSampler
181186
The sampler to use.

src/snake/core/handlers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
from .activations import BlockActivationHandler
55
from .noise import NoiseHandler
66
from .motion import RandomMotionImageHandler
7+
from .fov import FOVHandler
78

89
__all__ = [
910
"AbstractHandler",
1011
"HandlerList",
1112
"get_handler",
1213
"H",
14+
"FOVHandler",
1315
"BlockActivationHandler",
1416
"NoiseHandler",
1517
"RandomMotionImageHandler",

src/snake/core/handlers/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
"""Base handler module."""
22

33
from __future__ import annotations
4-
import yaml
4+
55
import dataclasses
66
from collections import UserList
7+
from typing import Any, ClassVar, TypeVar
78

8-
from ..._meta import MetaDCRegister
9-
from typing import ClassVar, TypeVar, Any
9+
import yaml
10+
from typing_extensions import dataclass_transform
1011

12+
from ..._meta import MetaDCRegister
13+
from ..phantom import DynamicData, KspaceDynamicData, Phantom
1114
from ..simulation import SimConfig
12-
from ..phantom import Phantom, DynamicData, KspaceDynamicData
1315

1416
T = TypeVar("T")
1517

16-
18+
@dataclass_transform(kw_only_default=True)
1719
class MetaHandler(MetaDCRegister):
1820
"""MetaClass for Handlers."""
1921

20-
dunder_name = "handler"
22+
dunder_name: ClassVar[str] = "handler"
2123

2224

2325
class AbstractHandler(metaclass=MetaHandler):

src/snake/core/handlers/fov.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
"""Handler for modifying the Field of View of the Phantom."""
1+
"""Handler for modifying the Field of View of the Phantom.
2+
3+
TODO: Add a FOV-motion handler that combines FOV and motion (moving the head is
4+
equivalent to changing the center point of FOV + angles).
5+
6+
"""
27

38
from __future__ import annotations
49
import warnings

src/snake/core/handlers/motion/image.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
"""Motion in the image domain."""
1+
"""Motion in the image domain.
2+
3+
TODO: Add constraints on maximal displacements
4+
TODO: Use FOV displacements
5+
"""
26

37
from copy import deepcopy
48

src/snake/core/sampling/base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44
import logging
55
from typing import ClassVar
6+
from typing_extensions import dataclass_transform
67
from numpy.typing import NDArray
78

89
from snake._meta import MetaDCRegister
@@ -11,6 +12,7 @@
1112
import ismrmrd as mrd
1213

1314

15+
@dataclass_transform(kw_only_default=True) # Required here for pyright to work.
1416
class MetaSampler(MetaDCRegister):
1517
"""MetaClass for Samplers."""
1618

src/snake/core/sampling/samplers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,10 +249,10 @@ class StackOfSpiralSampler(NonCartesianAcquisitionSampler):
249249

250250
acsz: float | int
251251
accelz: int
252-
orderz: VDSorder = VDSorder.TOP_DOWN
252+
orderz: str | VDSorder = VDSorder.TOP_DOWN
253253
nb_revolutions: int = 10
254254
spiral_name: str = "archimedes"
255-
pdfz: VDSpdf = VDSpdf.GAUSSIAN
255+
pdfz: str | VDSpdf = VDSpdf.GAUSSIAN
256256
constant: bool = False
257257
in_out: bool = True
258258
rotate_angle: AngleRotation = AngleRotation.ZERO

src/snake/mrd_utils/loader.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .._meta import LogMixin
1616

1717
if TYPE_CHECKING:
18+
from _typeshed import AnyPath
1819
from ..core import Phantom, DynamicData
1920
from ..core import SimConfig
2021

@@ -23,7 +24,7 @@
2324
log = logging.getLogger(__name__)
2425

2526

26-
def read_mrd_header(filename: os.PathLike | mrd.Dataset) -> mrd.xsd.ismrmrdHeader:
27+
def read_mrd_header(filename: AnyPath | mrd.Dataset) -> mrd.xsd.ismrmrdHeader:
2728
"""Read the header of the MRD file."""
2829
if isinstance(filename, mrd.Dataset):
2930
dataset = filename
@@ -50,7 +51,7 @@ class MRDLoader(LogMixin):
5051

5152
def __init__(
5253
self,
53-
filename: os.PathLike,
54+
filename: AnyPath,
5455
dataset_name: str = "dataset",
5556
writeable: bool = False,
5657
swmr: bool = False,
@@ -85,7 +86,7 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
8586

8687
def iter_frames(
8788
self, start: int | None = None, stop: int | None = None, step: int | None = None
88-
) -> Generator[tuple[int, NDArray, NDArray], None, None]:
89+
) -> Generator[tuple[int, NDArray[np.float32], NDArray[np.complex64]], None, None]:
8990
"""Iterate over kspace frames of the dataset."""
9091
if start is None:
9192
start = 0
@@ -97,7 +98,7 @@ def iter_frames(
9798
for i in np.arange(start, stop, step):
9899
yield i, *self.get_kspace_frame(i)
99100

100-
def get_kspace_frame(self, idx: int) -> tuple[NDArray, NDArray]:
101+
def get_kspace_frame(self, idx: int) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
101102
"""Get k-space frame trajectory/mask and data."""
102103
raise NotImplementedError()
103104

@@ -275,15 +276,15 @@ def get_sim_conf(self) -> SimConfig:
275276
"""Parse the sim config."""
276277
return parse_sim_conf(self.header)
277278

278-
def _get_image_data(self, name: str, idx: int = 0) -> NDArray | None:
279+
def _get_image_data(self, name: str, idx: int = 0) -> NDArray[np.complex64] | None:
279280
try:
280-
image = self._read_image(name, idx).data
281+
image = self._read_image(name, idx).data.astype(np.complex64)
281282
except LookupError:
282283
log.warning(f"No {name} found in the dataset.")
283284
return None
284285
return image
285286

286-
def get_smaps(self) -> NDArray | None:
287+
def get_smaps(self) -> NDArray[np.complex64] | None:
287288
"""Load the sensitivity maps from the dataset."""
288289
return self._get_image_data("smaps")
289290

@@ -344,7 +345,7 @@ class NonCartesianFrameDataLoader(MRDLoader):
344345

345346
def get_kspace_frame(
346347
self, idx: int, shot_dim: bool = False
347-
) -> tuple[np.ndarray, np.ndarray]:
348+
) -> tuple[NDArray[np.float32], NDArray[np.complex64]]:
348349
"""Get the k-space frame and the associated trajectory.
349350
350351
Parameters

src/snake/toolkit/plotting.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ def plot_frames_activ(
198198

199199

200200
def axis3dcut(
201-
background: NDArray,
202-
z_score: NDArray,
201+
background: NDArray[np.float32],
202+
z_score: NDArray[np.float32] | None,
203203
gt_roi: NDArray | None = None,
204204
width_inches: float = 7,
205205
cbar: bool = True,
@@ -312,17 +312,19 @@ def axis3dcut(
312312
slices_[i],
313313
bbox_[i],
314314
bg_cmap=bg_cmap,
315+
z_thresh=z_thresh,
316+
z_max=z_max
315317
)
316318

317319
if cbar:
318320
cax = type(ax)(fig, ax.get_position(original=True))
319321
cax.set_axes_locator(divider.new_locator(nx=3, ny=0, ny1=-1))
320322
if z_score is not None:
321323
im = ScalarMappable(norm="linear", cmap=get_coolgraywarm())
322-
im.set_clim(-11, 11)
324+
im.set_clim(-z_max, z_max)
323325
matplotlib.colorbar.Colorbar(cax, im, orientation="vertical")
324326
cax.set_ylabel("z-scores", labelpad=-20)
325-
cax.set_yticks(np.concatenate([-np.arange(3, 12, 2), np.arange(3, 12, 2)]))
327+
cax.set_yticks(np.concatenate([-np.arange(z_thresh, z_max+1, 2), np.arange(z_thresh, z_max+1, 2)]))
326328
else:
327329
# use the background image
328330
if vmin_vmax is None:

src/snake/toolkit/reconstructors/base.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
"""Base Class for Reconstructors."""
22

3-
from numpy.typing import NDArray
4-
53
import logging
64
from dataclasses import field
75
from typing import Any, ClassVar
8-
from ..._meta import MetaDCRegister
9-
from snake.mrd_utils import MRDLoader
6+
7+
from numpy.typing import NDArray
8+
from typing_extensions import dataclass_transform
9+
1010
from snake.core.simulation import SimConfig
11+
from snake.mrd_utils import MRDLoader
12+
13+
from ..._meta import MetaDCRegister
1114

1215

16+
@dataclass_transform(kw_only_default=True)
1317
class MetaReconstructor(MetaDCRegister):
1418
"""MetaClass Reconstructor."""
1519

16-
dunder_name = "reconstructor"
20+
dunder_name: ClassVar[str] = "reconstructor"
1721

1822

1923
class BaseReconstructor(metaclass=MetaReconstructor):
@@ -34,7 +38,7 @@ def setup(self, sim_conf: SimConfig) -> None:
3438
"""Set up the reconstructor."""
3539
self.log.info(f"Setup reconstructor {self.__class__.__name__}")
3640

37-
def reconstruct(self, data_loader: MRDLoader, sim_conf: SimConfig) -> NDArray:
41+
def reconstruct(self, data_loader: MRDLoader) -> NDArray:
3842
"""Reconstruct the kspace data to image space."""
3943
raise NotImplementedError
4044

0 commit comments

Comments
 (0)