7
7
#
8
8
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
9
9
"""Resampling utilities."""
10
- from warnings import warn
10
+
11
11
from pathlib import Path
12
12
import numpy as np
13
13
from nibabel .loadsave import load as _nbload
14
14
from nibabel .arrayproxy import get_obj_dtype
15
15
from scipy import ndimage as ndi
16
16
17
- from nitransforms .linear import Affine , LinearTransformsMapping
18
17
from nitransforms .base import (
19
18
ImageGrid ,
20
19
TransformError ,
21
20
SpatialReference ,
22
21
_as_homogeneous ,
23
22
)
24
23
25
- SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
24
+ SERIALIZE_VOLUME_WINDOW_WIDTH : int = 8
26
25
"""Minimum number of volumes to automatically serialize 4D transforms."""
27
26
28
27
@@ -96,58 +95,67 @@ def apply(
96
95
if isinstance (spatialimage , (str , Path )):
97
96
spatialimage = _nbload (str (spatialimage ))
98
97
99
- data = np .asanyarray (spatialimage .dataobj )
100
- data_nvols = 1 if data .ndim < 4 else data .shape [- 1 ]
98
+ # Avoid opening the data array just yet
99
+ input_dtype = get_obj_dtype (spatialimage .dataobj )
100
+ output_dtype = output_dtype or input_dtype
101
101
102
+ # Number of transformations
103
+ data_nvols = 1 if spatialimage .ndim < 4 else spatialimage .shape [- 1 ]
102
104
xfm_nvols = len (transform )
103
105
104
- if data_nvols == 1 and xfm_nvols > 1 :
105
- data = data [..., np .newaxis ]
106
- elif data_nvols != xfm_nvols :
106
+ if data_nvols != xfm_nvols and min (data_nvols , xfm_nvols ) > 1 :
107
107
raise ValueError (
108
108
"The fourth dimension of the data does not match the transform's shape."
109
109
)
110
110
111
- serialize_nvols = serialize_nvols if serialize_nvols and serialize_nvols > 1 else np .inf
112
- serialize_4d = max (data_nvols , xfm_nvols ) >= serialize_nvols
111
+ serialize_nvols = (
112
+ serialize_nvols if serialize_nvols and serialize_nvols > 1 else np .inf
113
+ )
114
+ n_resamplings = max (data_nvols , xfm_nvols )
115
+ serialize_4d = n_resamplings >= serialize_nvols
116
+
117
+ targets = None
118
+ if hasattr (transform , "to_field" ) and callable (transform .to_field ):
119
+ targets = ImageGrid (spatialimage ).index (
120
+ _as_homogeneous (
121
+ transform .to_field (reference = reference ).map (_ref .ndcoords .T ),
122
+ dim = _ref .ndim ,
123
+ )
124
+ )
125
+ elif xfm_nvols == 1 :
126
+ targets = ImageGrid (spatialimage ).index ( # data should be an image
127
+ _as_homogeneous (transform .map (_ref .ndcoords .T ), dim = _ref .ndim )
128
+ )
113
129
114
130
if serialize_4d :
115
- # Avoid opening the data array just yet
116
- input_dtype = get_obj_dtype (spatialimage .dataobj )
117
- output_dtype = output_dtype or input_dtype
118
-
119
- # Prepare physical coordinates of input (grid, points)
120
- xcoords = _ref .ndcoords .astype ("f4" ).T
121
-
122
- # Invert target's (moving) affine once
123
- ras2vox = ~ Affine (spatialimage .affine )
124
- dataobj = (
131
+ data = (
125
132
np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
126
- if spatialimage . ndim in ( 2 , 3 )
133
+ if data_nvols == 1
127
134
else None
128
135
)
129
136
130
137
# Order F ensures individual volumes are contiguous in memory
131
138
# Also matches NIfTI, making final save more efficient
132
139
resampled = np .zeros (
133
- (xcoords . shape [ 0 ] , len (transform )), dtype = output_dtype , order = "F"
140
+ (spatialimage . size , len (transform )), dtype = output_dtype , order = "F"
134
141
)
135
142
136
- for t , xfm_t in enumerate (transform ):
137
- # Map the input coordinates on to timepoint t of the target (moving)
138
- ycoords = xfm_t .map (xcoords )[..., : _ref .ndim ]
143
+ for t in range (n_resamplings ):
144
+ xfm_t = transform if n_resamplings == 1 else transform [t ]
139
145
140
- # Calculate corresponding voxel coordinates
141
- yvoxels = ras2vox .map (ycoords )[..., : _ref .ndim ]
146
+ if targets is None :
147
+ targets = ImageGrid (spatialimage ).index ( # data should be an image
148
+ _as_homogeneous (xfm_t .map (_ref .ndcoords .T ), dim = _ref .ndim )
149
+ )
142
150
143
151
# Interpolate
144
152
resampled [..., t ] = ndi .map_coordinates (
145
153
(
146
- dataobj
147
- if dataobj is not None
154
+ data
155
+ if data is not None
148
156
else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
149
157
),
150
- yvoxels . T ,
158
+ targets ,
151
159
output = output_dtype ,
152
160
order = order ,
153
161
mode = mode ,
@@ -156,19 +164,17 @@ def apply(
156
164
)
157
165
158
166
else :
159
- # For model-based nonlinear transforms, generate the corresponding dense field
160
- if hasattr (transform , "to_field" ) and callable (transform .to_field ):
161
- targets = ImageGrid (spatialimage ).index (
162
- _as_homogeneous (
163
- transform .to_field (reference = reference ).map (_ref .ndcoords .T ),
164
- dim = _ref .ndim ,
165
- )
166
- )
167
- else :
167
+ data = np .asanyarray (spatialimage .dataobj , dtype = input_dtype )
168
+
169
+ if targets is None :
168
170
targets = ImageGrid (spatialimage ).index ( # data should be an image
169
171
_as_homogeneous (transform .map (_ref .ndcoords .T ), dim = _ref .ndim )
170
172
)
171
173
174
+ # Cast 3D data into 4D if 4D nonsequential transform
175
+ if data_nvols == 1 and xfm_nvols > 1 :
176
+ data = data [..., np .newaxis ]
177
+
172
178
if transform .ndim == 4 :
173
179
targets = _as_homogeneous (targets .reshape (- 2 , targets .shape [0 ])).T
174
180
0 commit comments