Skip to content

Commit 761193f

Browse files
committed
Improve type matching to make tests pass
1 parent 28c5245 commit 761193f

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

reproject/adaptive/core.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,15 @@ def _reproject_adaptive_2d(
117117
"broadcast" across those images.
118118
"""
119119

120-
# Make sure image is floating point, but preserve the floating-point precision if it's floating point already
120+
# Make sure the input is floating point, and that it matches the output array type if specified. (Cython makes 32
121+
# and 64-bit versions of the inner function, but it's either everything is 32-bit, or everything is 64-bit). In
122+
# case of a disagreement between the input and output types, it's best to apply the output's type to the
123+
# input---the input is more likely to have weirdness (e.g. the wrong endianness, which Cython can't handle), and
124+
# the output type is either a safe default, if the user didn't provide an output array, or it has a user-specified
125+
# type.
121126
array_in = np.asarray(array)
122-
if not np.issubdtype(array.dtype, np.floating):
123-
array_in = np.asarray(array, dtype=float)
127+
if array_out is not None and array_in.dtype != array_out.dtype:
128+
array_in = np.asarray(array, dtype=array_out.dtype)
124129
shape_out = tuple(shape_out)
125130

126131
# Check dimensionality of WCS and shape_out
@@ -141,7 +146,9 @@ def _reproject_adaptive_2d(
141146
raise ValueError("Dimensions to be looped over must match exactly")
142147

143148
if array_out is None:
144-
array_out = np.empty(shape_out)
149+
# n.b. in the normal calling sequence, array_out is always generated by `_reproject_dispatcher` and passed into
150+
# this function.
151+
array_out = np.empty(shape_out, dtype=array_in.dtype)
145152

146153
if output_footprint is None:
147154
output_footprint = np.empty(shape_out)

0 commit comments

Comments
 (0)