@@ -97,7 +97,6 @@ def apply(
97
97
98
98
# Avoid opening the data array just yet
99
99
input_dtype = get_obj_dtype (spatialimage .dataobj )
100
- output_dtype = output_dtype or input_dtype
101
100
102
101
# Number of transformations
103
102
data_nvols = 1 if spatialimage .ndim < 4 else spatialimage .shape [- 1 ]
@@ -115,16 +114,17 @@ def apply(
115
114
serialize_4d = n_resamplings >= serialize_nvols
116
115
117
116
targets = None
117
+ ref_ndcoords = _ref .ndcoords .T
118
118
if hasattr (transform , "to_field" ) and callable (transform .to_field ):
119
119
targets = ImageGrid (spatialimage ).index (
120
120
_as_homogeneous (
121
- transform .to_field (reference = reference ).map (_ref . ndcoords . T ),
121
+ transform .to_field (reference = reference ).map (ref_ndcoords ),
122
122
dim = _ref .ndim ,
123
123
)
124
124
)
125
125
elif xfm_nvols == 1 :
126
126
targets = ImageGrid (spatialimage ).index ( # data should be an image
127
- _as_homogeneous (transform .map (_ref . ndcoords . T ), dim = _ref .ndim )
127
+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
128
128
)
129
129
130
130
if serialize_4d :
@@ -137,15 +137,15 @@ def apply(
137
137
# Order F ensures individual volumes are contiguous in memory
138
138
# Also matches NIfTI, making final save more efficient
139
139
resampled = np .zeros (
140
- (spatialimage . size , len (transform )), dtype = output_dtype , order = "F"
140
+ (len ( ref_ndcoords ) , len (transform )), dtype = input_dtype , order = "F"
141
141
)
142
142
143
143
for t in range (n_resamplings ):
144
144
xfm_t = transform if n_resamplings == 1 else transform [t ]
145
145
146
146
if targets is None :
147
147
targets = ImageGrid (spatialimage ).index ( # data should be an image
148
- _as_homogeneous (xfm_t .map (_ref . ndcoords . T ), dim = _ref .ndim )
148
+ _as_homogeneous (xfm_t .map (ref_ndcoords ), dim = _ref .ndim )
149
149
)
150
150
151
151
# Interpolate
@@ -156,7 +156,6 @@ def apply(
156
156
else spatialimage .dataobj [..., t ].astype (input_dtype , copy = False )
157
157
),
158
158
targets ,
159
- output = output_dtype ,
160
159
order = order ,
161
160
mode = mode ,
162
161
cval = cval ,
@@ -168,7 +167,7 @@ def apply(
168
167
169
168
if targets is None :
170
169
targets = ImageGrid (spatialimage ).index ( # data should be an image
171
- _as_homogeneous (transform .map (_ref . ndcoords . T ), dim = _ref .ndim )
170
+ _as_homogeneous (transform .map (ref_ndcoords ), dim = _ref .ndim )
172
171
)
173
172
174
173
# Cast 3D data into 4D if 4D nonsequential transform
@@ -181,7 +180,6 @@ def apply(
181
180
resampled = ndi .map_coordinates (
182
181
data ,
183
182
targets ,
184
- output = output_dtype ,
185
183
order = order ,
186
184
mode = mode ,
187
185
cval = cval ,
@@ -190,13 +188,14 @@ def apply(
190
188
191
189
if isinstance (_ref , ImageGrid ): # If reference is grid, reshape
192
190
hdr = _ref .header .copy () if _ref .header is not None else spatialimage .header .__class__ ()
193
- hdr .set_data_dtype (output_dtype )
191
+ hdr .set_data_dtype (output_dtype or spatialimage . header . get_data_dtype () )
194
192
195
193
moved = spatialimage .__class__ (
196
- resampled .reshape (_ref .shape if data . ndim < 4 else _ref .shape + (- 1 ,)),
194
+ resampled .reshape (_ref .shape if n_resamplings == 1 else _ref .shape + (- 1 ,)),
197
195
_ref .affine ,
198
196
hdr ,
199
197
)
200
198
return moved
201
199
202
- return resampled
200
+ output_dtype = output_dtype or input_dtype
201
+ return resampled .astype (output_dtype )
0 commit comments