Skip to content

Commit 9fc2967

Browse files
committed
feat!: smaps is now part of phantom.
This is a breaking change, but for the best: the smaps can now be modified by handlers and are also shared in memory.
1 parent 3bae00b commit 9fc2967

File tree

9 files changed

+98
-88
lines changed

9 files changed

+98
-88
lines changed

src/snake/_meta.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
T = TypeVar("T")
1919

20+
ThreeInts = tuple[int, int, int]
21+
ThreeFloats = tuple[float, float, float]
22+
2023

2124
def make_log_property(dunder_name: str) -> Callable:
2225
"""Create a property logger."""

src/snake/core/engine/base.py

Lines changed: 5 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
from ..simulation import SimConfig
3131
from .utils import get_ideal_phantom, get_noise
3232

33-
AnyPath = str | Path
33+
AnyPath = str | Path
34+
3435

3536
@dataclass_transform(kw_only_default=True)
3637
class MetaEngine(MetaDCRegister):
@@ -93,7 +94,6 @@ def _job_model_T2s(
9394
dyn_datas: list[DynamicData],
9495
sim_conf: SimConfig,
9596
trajectories: NDArray, # (Chunksize, N, 3)
96-
smaps: NDArray,
9797
*args: Any,
9898
**kwargs: Any,
9999
) -> NDArray:
@@ -105,7 +105,6 @@ def _job_model_simple(
105105
dyn_datas: list[DynamicData],
106106
sim_conf: SimConfig,
107107
trajectories: NDArray, # (Chunksize, N, 3)
108-
smaps: NDArray,
109108
*args: Any,
110109
**kwargs: Any,
111110
) -> NDArray:
@@ -122,7 +121,7 @@ def _acquire_ksp_job(
122121
chunk: Sequence[int],
123122
tmp_dir: str,
124123
shared_phantom_props: (
125-
tuple[str, ArrayProps, ArrayProps, ArrayProps] | None
124+
tuple[str, ArrayProps, ArrayProps, ArrayProps, ArrayProps | None] | None
126125
) = None,
127126
**kwargs: Mapping[str, Any],
128127
) -> str:
@@ -149,15 +148,12 @@ def _acquire_ksp_job(
149148
trajs = self._job_trajectories(data_loader, hdr, sim_conf, chunk)
150149

151150
_job_model = getattr(self, f"_job_model_{self.model}")
152-
smaps = None
153-
if sim_conf.hardware.n_coils > 1:
154-
smaps = data_loader.get_smaps()
155151
if shared_phantom_props is None:
156152
phantom = data_loader.get_phantom()
157-
ksp = _job_model(phantom, ddatas, sim_conf, trajs, smaps, **kwargs)
153+
ksp = _job_model(phantom, ddatas, sim_conf, trajs, **kwargs)
158154
else:
159155
with Phantom.from_shared_memory(*shared_phantom_props) as phantom:
160-
ksp = _job_model(phantom, ddatas, sim_conf, trajs, smaps, **kwargs)
156+
ksp = _job_model(phantom, ddatas, sim_conf, trajs, **kwargs)
161157

162158
chunk_file = os.path.join(tmp_dir, f"partial_{chunk[0]}-{chunk[-1]}.npy")
163159
np.save(chunk_file, ksp)
@@ -170,7 +166,6 @@ def __call__(
170166
phantom: Phantom,
171167
sim_conf: SimConfig,
172168
handlers: list[AbstractHandler] | HandlerList | None = None,
173-
smaps: NDArray | None = None,
174169
coil_cov: NDArray | None = None,
175170
worker_chunk_size: int = 0,
176171
n_workers: int = 0,
@@ -190,8 +185,6 @@ def __call__(
190185
The simulation configuration.
191186
handlers : list[AbstractHandler] | HandlerList | None, optional
192187
The handlers to use, by default None.
193-
smaps : NDArray | None, optional
194-
The sensitivity maps, by default None.
195188
coil_cov : NDArray | None, optional
196189
The coil covariance matrix, by default None.
197190
worker_chunk_size : int, optional
@@ -220,7 +213,6 @@ def __call__(
220213
phantom,
221214
sim_conf,
222215
handlers,
223-
smaps,
224216
coil_cov,
225217
self.model,
226218
self.slice_2d,

src/snake/core/engine/cartesian.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def _job_model_T2s(
6363
dyn_datas: list[DynamicData],
6464
sim_conf: SimConfig,
6565
trajectories: NDArray, # (Chunksize, N, 3)
66-
smaps: NDArray,
6766
slice_2d: bool = False,
6867
) -> np.ndarray:
6968
"""Acquire k-space data. With T2s decay."""
@@ -107,18 +106,20 @@ def _job_model_T2s(
107106
slice_location = flat_epi[0, 0] # FIXME: the slice is always axial.
108107
flat_epi = flat_epi[:, 1:]
109108
phantom_slice = phantom_state[:, slice_location]
110-
if smaps is None:
109+
if phantom.smaps is None:
111110
phantom_slice = phantom_slice[:, None, ...]
112111
else:
113-
smaps_ = smaps[:, slice_location]
112+
smaps_ = phantom.smaps[:, slice_location]
114113
phantom_slice = phantom_slice[:, None, ...] * smaps_
115114

116115
ksp = fft(phantom_slice, axis=(-2, -1))
117116
else:
118-
if smaps is None:
117+
if phantom.smaps is None:
119118
ksp = fft(phantom_state[:, None, ...], axis=(-3, -2, -1))
120119
else:
121-
ksp = fft(phantom_state[:, None, ...] * smaps, axis=(-3, -2, -1))
120+
ksp = fft(
121+
phantom_state[:, None, ...] * phantom.smaps, axis=(-3, -2, -1)
122+
)
122123

123124
for c in range(sim_conf.hardware.n_coils):
124125
ksp_coil_sum = np.zeros(
@@ -135,7 +136,6 @@ def _job_model_simple(
135136
dyn_datas: list[DynamicData],
136137
sim_conf: SimConfig,
137138
trajectories: NDArray, # (Chunksize, N, 3)
138-
smaps: NDArray,
139139
slice_2d: bool = False,
140140
) -> np.ndarray:
141141
"""Acquire k-space data. No T2s decay."""
@@ -157,17 +157,19 @@ def _job_model_simple(
157157
slice_location = flat_epi[0, 0] # FIXME: the slice is always axial.
158158
flat_epi = flat_epi[:, 1:] # Reduced to 2D.
159159
phantom_slice = phantom_state[slice_location]
160-
if smaps is None:
160+
if phantom.smaps is None:
161161
phantom_slice = phantom_slice[None, ...]
162162
else:
163-
smaps_ = smaps[:, slice_location]
163+
smaps_ = phantom.smaps[:, slice_location]
164164
phantom_slice = phantom_slice[None, ...] * smaps_
165165
ksp = fft(phantom_slice, axis=(-2, -1))
166166
else:
167-
if smaps is None:
167+
if phantom.smaps is None:
168168
ksp = fft(phantom_state[None, ...], axis=(-3, -2, -1))
169169
else:
170-
ksp = fft(phantom_state[None, ...] * smaps, axis=(-3, -2, -1))
170+
ksp = fft(
171+
phantom_state[None, ...] * phantom.smaps, axis=(-3, -2, -1)
172+
)
171173
for c in range(sim_conf.hardware.n_coils):
172174
ksp_coil = ksp[c]
173175
a = ksp_coil[tuple(flat_epi.T)]
@@ -247,7 +249,6 @@ def _job_model_T2s(
247249
dyn_datas: list[DynamicData],
248250
sim_conf: SimConfig,
249251
trajectories: NDArray, # (Chunksize, N, 3)
250-
smaps: NDArray,
251252
) -> np.ndarray:
252253
"""Acquire k-space data. With T2s decay."""
253254
readout_length = trajectories.shape[-2]
@@ -299,10 +300,12 @@ def _job_model_T2s(
299300
* frame_phantom.masks
300301
)
301302

302-
if smaps is None:
303+
if phantom.smaps is None:
303304
ksp = fft(phantom_state[:, None, ...], axis=(-3, -2, -1))
304305
else:
305-
ksp = fft(phantom_state[:, None, ...] * smaps, axis=(-3, -2, -1))
306+
ksp = fft(
307+
phantom_state[:, None, ...] * phantom.smaps, axis=(-3, -2, -1)
308+
)
306309
flat_evi = evi.reshape(-1, 3)
307310
for c in range(sim_conf.hardware.n_coils):
308311
ksp_coil_sum = np.zeros(
@@ -322,7 +325,6 @@ def _job_model_simple(
322325
dyn_datas: list[DynamicData],
323326
sim_conf: SimConfig,
324327
trajectories: NDArray, # (Chunksize, N, 3)
325-
smaps: NDArray,
326328
) -> np.ndarray:
327329
"""Acquire k-space data. No T2s decay."""
328330
final_ksp = np.zeros(
@@ -351,10 +353,10 @@ def _job_model_simple(
351353
* frame_phantom.masks,
352354
axis=0,
353355
)
354-
if smaps is None:
356+
if phantom.smaps is None:
355357
ksp = fft(phantom_state[None, ...], axis=(-3, -2, -1))
356358
else:
357-
ksp = fft(phantom_state[None, ...] * smaps, axis=(-3, -2, -1))
359+
ksp = fft(phantom_state[None, ...] * phantom.smaps, axis=(-3, -2, -1))
358360
flat_epi = epi_2d.reshape(-1, 3)
359361
for c in range(sim_conf.hardware.n_coils):
360362
ksp_coil = ksp[c]

src/snake/core/engine/nufft.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,6 @@ def _job_model_T2s(
9191
dyn_datas: list[DynamicData],
9292
sim_conf: SimConfig,
9393
trajectories: NDArray,
94-
smaps: NDArray,
9594
nufft_backend: str,
9695
slice_2d: bool = False,
9796
) -> np.ndarray:
@@ -105,7 +104,7 @@ def _job_model_T2s(
105104
nufft = NufftAcquisitionEngine._init_model_nufft(
106105
trajectories[0],
107106
sim_conf,
108-
smaps,
107+
phantom.smaps,
109108
backend=nufft_backend,
110109
slice_2d=slice_2d,
111110
)
@@ -121,8 +120,8 @@ def _job_model_T2s(
121120
if slice_2d:
122121
slice_loc = round((traj[0, -1] + 0.5) * sim_conf.shape[-1])
123122
nufft.samples = traj[:, :2]
124-
if smaps is not None:
125-
nufft.smaps = smaps[..., slice_loc]
123+
if phantom.smaps is not None:
124+
nufft.smaps = phantom.smaps[..., slice_loc]
126125
phantom_state = phantom_state[:, None, ..., slice_loc]
127126
else:
128127
phantom_state = phantom_state[:, None, ...]
@@ -140,7 +139,6 @@ def _job_model_simple(
140139
dyn_datas: list[DynamicData],
141140
sim_conf: SimConfig,
142141
trajectories: NDArray,
143-
smaps: NDArray,
144142
nufft_backend: str,
145143
slice_2d: bool = False,
146144
) -> np.ndarray:
@@ -153,7 +151,7 @@ def _job_model_simple(
153151
nufft = NufftAcquisitionEngine._init_model_nufft(
154152
trajectories[0],
155153
sim_conf,
156-
smaps,
154+
phantom.smaps,
157155
backend=nufft_backend,
158156
slice_2d=slice_2d,
159157
)
@@ -165,8 +163,8 @@ def _job_model_simple(
165163
if slice_2d:
166164
slice_loc = int((traj[0, -1] + 0.5) * sim_conf.shape[-1])
167165
nufft.samples = traj[:, :2]
168-
if smaps is not None:
169-
nufft.smaps = smaps[..., slice_loc]
166+
if phantom.smaps is not None:
167+
nufft.smaps = phantom.smaps[..., slice_loc]
170168
phantom_state = phantom_state[None, ..., slice_loc]
171169
else:
172170
nufft.samples = traj

src/snake/core/handlers/fov.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@
1616
from snake.core.parallel import run_parallel
1717
from snake.core.phantom import Phantom
1818
from snake.core.simulation import SimConfig
19-
19+
from snake._meta import ThreeInts, ThreeFloats
2020

2121
# TODO allow to use cupy for faster computation (if available)
2222

23-
ThreeInts = tuple[int, int, int]
24-
ThreeFloats = tuple[float, float, float]
25-
2623

2724
def extract_rotated_3d_region(
2825
volume: NDArray,
@@ -134,7 +131,13 @@ def get_static(self, phantom: Phantom, sim_conf: SimConfig) -> Phantom:
134131
),
135132
dtype=phantom.masks.dtype,
136133
)
137-
print("=======", size_vox, zoom_factor)
134+
new_smaps = np.zeros(
135+
(
136+
phantom.smaps.shape[0],
137+
*tuple(round(size_vox[i] / zoom_factor[i]) for i in range(3)),
138+
),
139+
dtype=phantom.smaps.dtype,
140+
)
138141

139142
run_parallel(
140143
_apply_transform,
@@ -146,9 +149,22 @@ def get_static(self, phantom: Phantom, sim_conf: SimConfig) -> Phantom:
146149
angles=self.angles,
147150
zoom_factor=zoom_factor,
148151
)
152+
153+
run_parallel(
154+
_apply_transform,
155+
phantom.smaps,
156+
new_smaps,
157+
parallel_axis=0,
158+
center=center_vox,
159+
size=size_vox,
160+
angles=self.angles,
161+
zoom_factor=zoom_factor,
162+
)
163+
149164
# Create a new phantom with updated masks
150165
new_phantom = phantom.copy()
151166
new_phantom.masks = new_masks
167+
new_phantom.smaps = new_smaps
152168

153169
# update the sim_config
154170
new_shape = new_phantom.anat_shape

src/snake/core/parallel.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def array_from_shm(
101101
shms = []
102102
arrays: list[NDArray] = []
103103
for prop in array_props:
104+
if prop is None: # optional arrays are ignored
105+
arrays.append(None)
106+
continue
104107
nbytes = int(np.dtype(prop.dtype).itemsize * np.prod(prop.shape))
105108
shms.append(SharedMemory(name=prop.name, size=nbytes))
106109
arrays.append(

0 commit comments

Comments
 (0)