10
10
11
11
import asyncio
12
12
from os import cpu_count
13
+ from contextlib import suppress
13
14
from functools import partial
14
15
from pathlib import Path
15
16
from typing import Callable , TypeVar , Union
@@ -108,12 +109,17 @@ async def _apply_serial(
108
109
semaphore = asyncio .Semaphore (max_concurrent )
109
110
110
111
for t in range (n_resamplings ):
111
- xfm_t = transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
112
+ xfm_t = (
113
+ transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
114
+ )
112
115
113
- if targets is None :
114
- targets = ImageGrid (spatialimage ).index ( # data should be an image
116
+ targets_t = (
117
+ ImageGrid (spatialimage ).index (
115
118
_as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116
119
)
120
+ if targets is None
121
+ else targets [t , ...]
122
+ )
117
123
118
124
data_t = (
119
125
data
@@ -127,7 +133,7 @@ async def _apply_serial(
127
133
partial (
128
134
ndi .map_coordinates ,
129
135
data_t ,
130
- targets ,
136
+ targets_t ,
131
137
output = output [..., t ],
132
138
order = order ,
133
139
mode = mode ,
@@ -255,11 +261,22 @@ def apply(
255
261
dim = _ref .ndim ,
256
262
)
257
263
)
258
- elif xfm_nvols == 1 :
259
- targets = ImageGrid (spatialimage ).index ( # data should be an image
260
- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
264
+ else :
265
+ # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266
+ targets = (
267
+ ImageGrid (spatialimage ).index (
268
+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
269
+ )
270
+ if targets is None
271
+ else targets
261
272
)
262
273
274
+ if targets .ndim == 3 :
275
+ targets = np .rollaxis (targets , targets .ndim - 1 , 0 )
276
+ else :
277
+ assert targets .ndim == 2
278
+ targets = targets [np .newaxis , ...]
279
+
263
280
if serialize_4d :
264
281
data = (
265
282
np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
@@ -294,17 +311,24 @@ def apply(
294
311
else :
295
312
data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
296
313
297
- if targets is None :
298
- targets = ImageGrid (spatialimage ).index ( # data should be an image
299
- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
300
- )
301
-
314
+ if data_nvols == 1 and xfm_nvols == 1 :
315
+ targets = np .squeeze (targets )
316
+ assert targets .ndim == 2
302
317
# Cast 3D data into 4D if 4D nonsequential transform
303
- if data_nvols == 1 and xfm_nvols > 1 :
318
+ elif data_nvols == 1 and xfm_nvols > 1 :
304
319
data = data [..., np .newaxis ]
305
320
306
- if transform .ndim == 4 :
307
- targets = _as_homogeneous (targets .reshape (- 2 , targets .shape [0 ])).T
321
+ if xfm_nvols > 1 :
322
+ assert targets .ndim == 3
323
+ n_time , n_dim , n_vox = targets .shape
324
+ # Reshape to (3, n_time x n_vox)
325
+ ijk_targets = np .rollaxis (targets , 0 , 2 ).reshape ((n_dim , - 1 ))
326
+ time_row = np .repeat (np .arange (n_time ), n_vox )[None , :]
327
+
328
+ # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
329
+ # t is the slowest-changing axis, so we put it first
330
+ targets = np .vstack ((time_row , ijk_targets ))
331
+ data = np .rollaxis (data , data .ndim - 1 , 0 )
308
332
309
333
resampled = ndi .map_coordinates (
310
334
data ,
@@ -323,11 +347,19 @@ def apply(
323
347
)
324
348
hdr .set_data_dtype (output_dtype or spatialimage .header .get_data_dtype ())
325
349
326
- moved = spatialimage .__class__ (
327
- resampled .reshape (_ref .shape if n_resamplings == 1 else _ref .shape + (- 1 ,)),
328
- _ref .affine ,
329
- hdr ,
330
- )
350
+ if serialize_4d :
351
+ resampled = resampled .reshape (
352
+ _ref .shape
353
+ if n_resamplings == 1
354
+ else _ref .shape + (resampled .shape [- 1 ],)
355
+ )
356
+ else :
357
+ resampled = resampled .reshape ((- 1 , * _ref .shape ))
358
+ resampled = np .rollaxis (resampled , 0 , resampled .ndim )
359
+ with suppress (ValueError ):
360
+ resampled = np .squeeze (resampled , axis = 3 )
361
+
362
+ moved = spatialimage .__class__ (resampled , _ref .affine , hdr )
331
363
return moved
332
364
333
365
output_dtype = output_dtype or input_dtype
0 commit comments