10
10
11
11
import asyncio
12
12
from os import cpu_count
13
+ from contextlib import suppress
13
14
from functools import partial
14
15
from pathlib import Path
15
16
from typing import Callable , TypeVar , Union
@@ -108,14 +109,16 @@ async def _apply_serial(
108
109
semaphore = asyncio .Semaphore (max_concurrent )
109
110
110
111
for t in range (n_resamplings ):
111
- xfm_t = transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
112
+ xfm_t = (
113
+ transform if (n_resamplings == 1 or transform .ndim < 4 ) else transform [t ]
114
+ )
112
115
113
116
targets_t = (
114
117
ImageGrid (spatialimage ).index (
115
118
_as_homogeneous (xfm_t .map (ref_ndcoords ), dim = ref_ndim )
116
119
)
117
120
if targets is None
118
- else targets
121
+ else targets [ t , ...]
119
122
)
120
123
121
124
data_t = (
@@ -258,11 +261,22 @@ def apply(
258
261
dim = _ref .ndim ,
259
262
)
260
263
)
261
- elif xfm_nvols == 1 :
262
- targets = ImageGrid (spatialimage ).index ( # data should be an image
263
- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
264
+ else :
265
+ # Targets' shape is (Nt, 3, Nv) with Nv = Num. voxels, Nt = Num. timepoints.
266
+ targets = (
267
+ ImageGrid (spatialimage ).index (
268
+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
269
+ )
270
+ if targets is None
271
+ else targets
264
272
)
265
273
274
+ if targets .ndim == 3 :
275
+ targets = np .rollaxis (targets , targets .ndim - 1 , 0 )
276
+ else :
277
+ assert targets .ndim == 2
278
+ targets = targets [np .newaxis , ...]
279
+
266
280
if serialize_4d :
267
281
data = (
268
282
np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
@@ -297,17 +311,24 @@ def apply(
297
311
else :
298
312
data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
299
313
300
- if targets is None :
301
- targets = ImageGrid (spatialimage ).index ( # data should be an image
302
- _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
303
- )
304
-
314
+ if data_nvols == 1 and xfm_nvols == 1 :
315
+ targets = np .squeeze (targets )
316
+ assert targets .ndim == 2
305
317
# Cast 3D data into 4D if 4D nonsequential transform
306
- if data_nvols == 1 and xfm_nvols > 1 :
318
+ elif data_nvols == 1 and xfm_nvols > 1 :
307
319
data = data [..., np .newaxis ]
308
320
309
- if transform .ndim == 4 :
310
- targets = _as_homogeneous (targets .reshape (- 2 , targets .shape [0 ])).T
321
+ if xfm_nvols > 1 :
322
+ assert targets .ndim == 3
323
+ n_time , n_dim , n_vox = targets .shape
324
+ # Reshape to (3, n_time x n_vox)
325
+ ijk_targets = np .rollaxis (targets , 0 , 2 ).reshape ((n_dim , - 1 ))
326
+ time_row = np .repeat (np .arange (n_time ), n_vox )[None , :]
327
+
328
+ # Now targets is (4, n_vox x n_time), with indexes (t, i, j, k)
329
+ # t is the slowest-changing axis, so we put it first
330
+ targets = np .vstack ((time_row , ijk_targets ))
331
+ data = np .rollaxis (data , data .ndim - 1 , 0 )
311
332
312
333
resampled = ndi .map_coordinates (
313
334
data ,
@@ -326,11 +347,19 @@ def apply(
326
347
)
327
348
hdr .set_data_dtype (output_dtype or spatialimage .header .get_data_dtype ())
328
349
329
- moved = spatialimage .__class__ (
330
- resampled .reshape (_ref .shape if n_resamplings == 1 else _ref .shape + (- 1 ,)),
331
- _ref .affine ,
332
- hdr ,
333
- )
350
+ if serialize_4d :
351
+ resampled = resampled .reshape (
352
+ _ref .shape
353
+ if n_resamplings == 1
354
+ else _ref .shape + (resampled .shape [- 1 ],)
355
+ )
356
+ else :
357
+ resampled = resampled .reshape ((- 1 , * _ref .shape ))
358
+ resampled = np .rollaxis (resampled , 0 , resampled .ndim )
359
+ with suppress (ValueError ):
360
+ resampled = np .squeeze (resampled , axis = 3 )
361
+
362
+ moved = spatialimage .__class__ (resampled , _ref .affine , hdr )
334
363
return moved
335
364
336
365
output_dtype = output_dtype or input_dtype
0 commit comments