Skip to content

Commit 6bcf97e

Browse files
committed
Merge remote-tracking branch 'upstream/main'
2 parents ff982d4 + 11cd3ec commit 6bcf97e

File tree

24 files changed

+283
-229
lines changed

24 files changed

+283
-229
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ jobs:
4242
pip install -e .[dev]
4343
- name: Linters
4444
run: |
45-
black --check src tests
46-
ruff check .
45+
ruff check src
46+
ruff format --check src
4747
- name: Annotate locations with typos
4848
if: always()
4949
uses: codespell-project/codespell-problem-matcher@v1

examples/anatomical/example_fov_select.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@
3737

3838

3939
# the center and size target_res are all in millimeters.
40-
fov_handlers = FOVHandler(
40+
fov_handler = FOVHandler(
4141
center=(90, 110, 100),
4242
size=(192, 192, 128),
4343
angles=(5, 0, 0),
4444
target_res=(2.0, 2.0, 2.0),
4545
)
4646

47-
new_phantom = fov_handlers.get_static(phantom, sim_conf)
47+
new_phantom = fov_handler.get_static(phantom, sim_conf)
4848

4949
# %%
5050
new_contrast = get_ideal_phantom(new_phantom, sim_conf)

pyproject.toml

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -131,17 +131,15 @@ convention = "numpy"
131131
"tests/test_*.py" = ["D", "ANN"]
132132

133133

134-
[tool.pylsp-mypy]
135-
enabled = true
136-
live_mode = true
137-
strict = true
138-
#overrides = ["--ignore-missing-imports"]
134+
[tool.basedpyright]
139135

140-
[tool.mypy]
141-
ignore_missing_imports=true
142-
exclude = ["examples/"]
143-
#overrides = ["--ignore-missing-imports"]
136+
typeCheckingMode = "off"
144137

145138
[tool.codespell]
146139
ignore-words-list = ["TE","fpr"]
147140
skip = ["docs/generated", "*.ipynb"]
141+
142+
[dependency-groups]
143+
dev = [
144+
"ipykernel>=6.29.5",
145+
]

src/snake/_meta.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
T = TypeVar("T")
1919

20+
ThreeInts = tuple[int, int, int]
21+
ThreeFloats = tuple[float, float, float]
22+
2023

2124
def make_log_property(dunder_name: str) -> Callable:
2225
"""Create a property logger."""
@@ -155,7 +158,6 @@ class ENVCONFIG(metaclass=Singleton):
155158

156159
@classmethod
157160
def __getitem__(cls, key: str) -> Any:
158-
159161
if key in os.environ:
160162
return os.environ[key]
161163
return getattr(cls, key)

src/snake/core/engine/base.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from collections.abc import Mapping, Sequence
1010
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
1111
from multiprocessing.managers import SharedMemoryManager
12+
from pathlib import Path
1213
from tempfile import TemporaryDirectory
1314
from typing import Any, ClassVar
1415

@@ -17,21 +18,26 @@
1718
from numpy.typing import NDArray
1819
from tqdm.auto import tqdm
1920

21+
from typing_extensions import dataclass_transform
22+
2023
from snake._meta import EnvConfig, MetaDCRegister, batched
2124

2225
from ...mrd_utils import MRDLoader, make_base_mrd
26+
from ..handlers import AbstractHandler, HandlerList
2327
from ..parallel import ArrayProps
2428
from ..phantom import DynamicData, Phantom, PropTissueEnum
25-
from ..simulation import SimConfig
2629
from ..sampling import BaseSampler
27-
from ..handlers import AbstractHandler, HandlerList
30+
from ..simulation import SimConfig
2831
from .utils import get_ideal_phantom, get_noise
2932

33+
AnyPath = str | Path
34+
3035

36+
@dataclass_transform(kw_only_default=True)
3137
class MetaEngine(MetaDCRegister):
3238
"""MetaClass for engines."""
3339

34-
dunder_name = "engine"
40+
dunder_name: ClassVar[str] = "engine"
3541

3642

3743
class BaseAcquisitionEngine(metaclass=MetaEngine):
@@ -88,7 +94,6 @@ def _job_model_T2s(
8894
dyn_datas: list[DynamicData],
8995
sim_conf: SimConfig,
9096
trajectories: NDArray, # (Chunksize, N, 3)
91-
smaps: NDArray,
9297
*args: Any,
9398
**kwargs: Any,
9499
) -> NDArray:
@@ -100,7 +105,6 @@ def _job_model_simple(
100105
dyn_datas: list[DynamicData],
101106
sim_conf: SimConfig,
102107
trajectories: NDArray, # (Chunksize, N, 3)
103-
smaps: NDArray,
104108
*args: Any,
105109
**kwargs: Any,
106110
) -> NDArray:
@@ -113,11 +117,11 @@ def _write_chunk_data(
113117

114118
def _acquire_ksp_job(
115119
self,
116-
filename: os.PathLike,
120+
filename: AnyPath,
117121
chunk: Sequence[int],
118122
tmp_dir: str,
119123
shared_phantom_props: (
120-
tuple[str, ArrayProps, ArrayProps, ArrayProps] | None
124+
tuple[str, ArrayProps, ArrayProps, ArrayProps, ArrayProps | None] | None
121125
) = None,
122126
**kwargs: Mapping[str, Any],
123127
) -> str:
@@ -144,28 +148,24 @@ def _acquire_ksp_job(
144148
trajs = self._job_trajectories(data_loader, hdr, sim_conf, chunk)
145149

146150
_job_model = getattr(self, f"_job_model_{self.model}")
147-
smaps = None
148-
if sim_conf.hardware.n_coils > 1:
149-
smaps = data_loader.get_smaps()
150151
if shared_phantom_props is None:
151152
phantom = data_loader.get_phantom()
152-
ksp = _job_model(phantom, ddatas, sim_conf, trajs, smaps, **kwargs)
153+
ksp = _job_model(phantom, ddatas, sim_conf, trajs, **kwargs)
153154
else:
154155
with Phantom.from_shared_memory(*shared_phantom_props) as phantom:
155-
ksp = _job_model(phantom, ddatas, sim_conf, trajs, smaps, **kwargs)
156+
ksp = _job_model(phantom, ddatas, sim_conf, trajs, **kwargs)
156157

157158
chunk_file = os.path.join(tmp_dir, f"partial_{chunk[0]}-{chunk[-1]}.npy")
158159
np.save(chunk_file, ksp)
159160
return chunk_file
160161

161162
def __call__(
162163
self,
163-
filename: os.PathLike,
164+
filename: AnyPath,
164165
sampler: BaseSampler,
165166
phantom: Phantom,
166167
sim_conf: SimConfig,
167168
handlers: list[AbstractHandler] | HandlerList | None = None,
168-
smaps: NDArray | None = None,
169169
coil_cov: NDArray | None = None,
170170
worker_chunk_size: int = 0,
171171
n_workers: int = 0,
@@ -175,7 +175,7 @@ def __call__(
175175
176176
Parameters
177177
----------
178-
filename : os.PathLike
178+
filename : AnyPath
179179
The path to the MRD file.
180180
sampler : BaseSampler
181181
The sampler to use.
@@ -185,8 +185,6 @@ def __call__(
185185
The simulation configuration.
186186
handlers : list[AbstractHandler] | HandlerList | None, optional
187187
The handlers to use, by default None.
188-
smaps : NDArray | None, optional
189-
The sensitivity maps, by default None.
190188
coil_cov : NDArray | None, optional
191189
The coil covariance matrix, by default None.
192190
worker_chunk_size : int, optional
@@ -215,7 +213,6 @@ def __call__(
215213
phantom,
216214
sim_conf,
217215
handlers,
218-
smaps,
219216
coil_cov,
220217
self.model,
221218
self.slice_2d,

src/snake/core/engine/cartesian.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def _job_model_T2s(
6363
dyn_datas: list[DynamicData],
6464
sim_conf: SimConfig,
6565
trajectories: NDArray, # (Chunksize, N, 3)
66-
smaps: NDArray,
6766
slice_2d: bool = False,
6867
) -> np.ndarray:
6968
"""Acquire k-space data. With T2s decay."""
@@ -107,18 +106,20 @@ def _job_model_T2s(
107106
slice_location = flat_epi[0, 0] # FIXME: the slice is always axial.
108107
flat_epi = flat_epi[:, 1:]
109108
phantom_slice = phantom_state[:, slice_location]
110-
if smaps is None:
109+
if phantom.smaps is None:
111110
phantom_slice = phantom_slice[:, None, ...]
112111
else:
113-
smaps_ = smaps[:, slice_location]
112+
smaps_ = phantom.smaps[:, slice_location]
114113
phantom_slice = phantom_slice[:, None, ...] * smaps_
115114

116115
ksp = fft(phantom_slice, axis=(-2, -1))
117116
else:
118-
if smaps is None:
117+
if phantom.smaps is None:
119118
ksp = fft(phantom_state[:, None, ...], axis=(-3, -2, -1))
120119
else:
121-
ksp = fft(phantom_state[:, None, ...] * smaps, axis=(-3, -2, -1))
120+
ksp = fft(
121+
phantom_state[:, None, ...] * phantom.smaps, axis=(-3, -2, -1)
122+
)
122123

123124
for c in range(sim_conf.hardware.n_coils):
124125
ksp_coil_sum = np.zeros(
@@ -135,7 +136,6 @@ def _job_model_simple(
135136
dyn_datas: list[DynamicData],
136137
sim_conf: SimConfig,
137138
trajectories: NDArray, # (Chunksize, N, 3)
138-
smaps: NDArray,
139139
slice_2d: bool = False,
140140
) -> np.ndarray:
141141
"""Acquire k-space data. No T2s decay."""
@@ -157,17 +157,19 @@ def _job_model_simple(
157157
slice_location = flat_epi[0, 0] # FIXME: the slice is always axial.
158158
flat_epi = flat_epi[:, 1:] # Reduced to 2D.
159159
phantom_slice = phantom_state[slice_location]
160-
if smaps is None:
160+
if phantom.smaps is None:
161161
phantom_slice = phantom_slice[None, ...]
162162
else:
163-
smaps_ = smaps[:, slice_location]
163+
smaps_ = phantom.smaps[:, slice_location]
164164
phantom_slice = phantom_slice[None, ...] * smaps_
165165
ksp = fft(phantom_slice, axis=(-2, -1))
166166
else:
167-
if smaps is None:
167+
if phantom.smaps is None:
168168
ksp = fft(phantom_state[None, ...], axis=(-3, -2, -1))
169169
else:
170-
ksp = fft(phantom_state[None, ...] * smaps, axis=(-3, -2, -1))
170+
ksp = fft(
171+
phantom_state[None, ...] * phantom.smaps, axis=(-3, -2, -1)
172+
)
171173
for c in range(sim_conf.hardware.n_coils):
172174
ksp_coil = ksp[c]
173175
a = ksp_coil[tuple(flat_epi.T)]
@@ -247,7 +249,6 @@ def _job_model_T2s(
247249
dyn_datas: list[DynamicData],
248250
sim_conf: SimConfig,
249251
trajectories: NDArray, # (Chunksize, N, 3)
250-
smaps: NDArray,
251252
) -> np.ndarray:
252253
"""Acquire k-space data. With T2s decay."""
253254
readout_length = trajectories.shape[-2]
@@ -299,10 +300,12 @@ def _job_model_T2s(
299300
* frame_phantom.masks
300301
)
301302

302-
if smaps is None:
303+
if phantom.smaps is None:
303304
ksp = fft(phantom_state[:, None, ...], axis=(-3, -2, -1))
304305
else:
305-
ksp = fft(phantom_state[:, None, ...] * smaps, axis=(-3, -2, -1))
306+
ksp = fft(
307+
phantom_state[:, None, ...] * phantom.smaps, axis=(-3, -2, -1)
308+
)
306309
flat_evi = evi.reshape(-1, 3)
307310
for c in range(sim_conf.hardware.n_coils):
308311
ksp_coil_sum = np.zeros(
@@ -322,7 +325,6 @@ def _job_model_simple(
322325
dyn_datas: list[DynamicData],
323326
sim_conf: SimConfig,
324327
trajectories: NDArray, # (Chunksize, N, 3)
325-
smaps: NDArray,
326328
) -> np.ndarray:
327329
"""Acquire k-space data. No T2s decay."""
328330
final_ksp = np.zeros(
@@ -351,10 +353,10 @@ def _job_model_simple(
351353
* frame_phantom.masks,
352354
axis=0,
353355
)
354-
if smaps is None:
356+
if phantom.smaps is None:
355357
ksp = fft(phantom_state[None, ...], axis=(-3, -2, -1))
356358
else:
357-
ksp = fft(phantom_state[None, ...] * smaps, axis=(-3, -2, -1))
359+
ksp = fft(phantom_state[None, ...] * phantom.smaps, axis=(-3, -2, -1))
358360
flat_epi = epi_2d.reshape(-1, 3)
359361
for c in range(sim_conf.hardware.n_coils):
360362
ksp_coil = ksp[c]

src/snake/core/engine/nufft.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _job_trajectories(
2828
dataset: mrd.Dataset,
2929
hdr: mrd.xsd.ismrmrdHeader,
3030
sim_conf: SimConfig,
31-
shot_idx: Sequence[int],
31+
shot_idx: Sequence[int] | int,
3232
) -> NDArray:
3333
"""Get Non Cartesian trajectories from the dataset.
3434
@@ -72,10 +72,12 @@ def _init_model_nufft(
7272
if slice_2d:
7373
shape_ = sim_conf.shape[:-1]
7474
if smaps is not None:
75-
smaps_ = smaps[..., 0]
75+
smaps_ = np.ascontiguousarray(
76+
smaps[..., 0]
77+
) # will be updated in the loop
7678

7779
nufft = get_operator(backend)(
78-
samples, # dummy samples locs
80+
samples, # will be updated in the loop
7981
shape=shape_,
8082
n_coils=n_coils,
8183
smaps=smaps_,
@@ -91,7 +93,6 @@ def _job_model_T2s(
9193
dyn_datas: list[DynamicData],
9294
sim_conf: SimConfig,
9395
trajectories: NDArray,
94-
smaps: NDArray,
9596
nufft_backend: str,
9697
slice_2d: bool = False,
9798
) -> np.ndarray:
@@ -105,7 +106,7 @@ def _job_model_T2s(
105106
nufft = NufftAcquisitionEngine._init_model_nufft(
106107
trajectories[0],
107108
sim_conf,
108-
smaps,
109+
phantom.smaps,
109110
backend=nufft_backend,
110111
slice_2d=slice_2d,
111112
)
@@ -121,8 +122,8 @@ def _job_model_T2s(
121122
if slice_2d:
122123
slice_loc = round((traj[0, -1] + 0.5) * sim_conf.shape[-1])
123124
nufft.samples = traj[:, :2]
124-
if smaps is not None:
125-
nufft.smaps = smaps[..., slice_loc]
125+
if phantom.smaps is not None:
126+
nufft.smaps = np.ascontiguousarray(phantom.smaps[..., slice_loc])
126127
phantom_state = phantom_state[:, None, ..., slice_loc]
127128
else:
128129
phantom_state = phantom_state[:, None, ...]
@@ -140,7 +141,6 @@ def _job_model_simple(
140141
dyn_datas: list[DynamicData],
141142
sim_conf: SimConfig,
142143
trajectories: NDArray,
143-
smaps: NDArray,
144144
nufft_backend: str,
145145
slice_2d: bool = False,
146146
) -> np.ndarray:
@@ -153,7 +153,7 @@ def _job_model_simple(
153153
nufft = NufftAcquisitionEngine._init_model_nufft(
154154
trajectories[0],
155155
sim_conf,
156-
smaps,
156+
phantom.smaps,
157157
backend=nufft_backend,
158158
slice_2d=slice_2d,
159159
)
@@ -165,8 +165,8 @@ def _job_model_simple(
165165
if slice_2d:
166166
slice_loc = int((traj[0, -1] + 0.5) * sim_conf.shape[-1])
167167
nufft.samples = traj[:, :2]
168-
if smaps is not None:
169-
nufft.smaps = smaps[..., slice_loc]
168+
if phantom.smaps is not None:
169+
nufft.smaps = phantom.smaps[..., slice_loc]
170170
phantom_state = phantom_state[None, ..., slice_loc]
171171
else:
172172
nufft.samples = traj

0 commit comments

Comments
 (0)