Skip to content

Commit b3439c7

Browse files
committed
Fixed Spectrum1D parsing in Horne and its tests
1 parent adf54a3 commit b3439c7

File tree

2 files changed

+106
-53
lines changed

2 files changed

+106
-53
lines changed

specreduce/extract.py

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ def __call__(self, image=None, trace_object=None, width=None,
216216

217217
# extract
218218
ext1d = np.sum(self.image.data * wimg, axis=crossdisp_axis)
219-
return _to_spectrum1d_pixels(ext1d * self.image.unit)
219+
return Spectrum1D(ext1d * self.image.unit,
220+
spectral_axis=self.image.spectral_axis)
220221

221222

222223
@dataclass
@@ -280,50 +281,86 @@ class HorneExtract(SpecreduceOperation):
280281
def spectrum(self):
281282
return self.__call__()
282283

283-
def _parse_image(self, variance=None, mask=None, unit=None):
284+
def _parse_image(self, image,
285+
variance=None, mask=None, unit=None, disp_axis=1):
284286
"""
285-
Convert all accepted image types to a consistently formatted Spectrum1D.
286-
Takes some extra arguments exactly as they come from self.__call__() to
287-
handle cases where users specify them as arguments instead of as
288-
attributes of their image object.
287+
Convert all accepted image types to a consistently formatted
288+
Spectrum1D object.
289+
290+
HorneExtract needs its own version of this method because it is
291+
more stringent in its requirements for input images. The extra
292+
arguments are needed to handle cases where these parameters were
293+
specified as arguments and those where they came as attributes
294+
of the image object.
295+
296+
Parameters
297+
----------
298+
image : `~astropy.nddata.NDData`-like or array-like, required
299+
The image to be parsed. If None, defaults to class' own
300+
image attribute.
301+
variance : `~numpy.ndarray`, optional
302+
(Only used if ``image`` is not an NDData object.)
303+
The associated variances for each pixel in the image. Must
304+
have the same dimensions as ``image``. If all zeros, the variance
305+
will be ignored and treated as all ones. If any zeros, those
306+
elements will be excluded via masking. If any negative values,
307+
an error will be raised.
308+
mask : `~numpy.ndarray`, optional
309+
(Only used if ``image`` is not an NDData object.)
310+
Whether to mask each pixel in the image. Must have the same
311+
dimensions as ``image``. If blank, all non-NaN pixels are
312+
unmasked.
313+
unit : `~astropy.units.Unit` or str, optional
314+
(Only used if ``image`` is not an NDData object.)
315+
The associated unit for the data in ``image``. If blank,
316+
fluxes are interpreted as unitless.
317+
disp_axis : int, optional
318+
The index of the image's dispersion axis. Should not be
319+
changed until operations can handle variable image
320+
orientations. [default: 1]
289321
"""
290322

291-
if isinstance(self.image, np.ndarray):
292-
img = self.image
293-
elif isinstance(self.image, u.quantity.Quantity):
294-
img = self.image.value
323+
if isinstance(image, np.ndarray):
324+
img = image
325+
elif isinstance(image, u.quantity.Quantity):
326+
img = image.value
295327
else: # NDData, including CCDData and Spectrum1D
296-
img = self.image.data
328+
img = image.data
297329

298330
# mask is set as None when not specified upon creating a Spectrum1D
299331
# object, so we must check whether it is absent *and* whether it's
300332
# present but set as None
301-
if getattr(self.image, 'mask', None) is not None:
302-
mask = self.image.mask
333+
if getattr(image, 'mask', None) is not None:
334+
mask = image.mask
335+
elif mask is not None:
336+
pass
303337
else:
304338
mask = np.ma.masked_invalid(img).mask
305339

340+
if img.shape != mask.shape:
341+
raise ValueError('image and mask shapes must match.')
342+
306343
# Process uncertainties, converting to variances when able and throwing
307344
# an error when uncertainties are missing or less easily converted
308-
if (hasattr(self.image, 'uncertainty')
309-
and self.image.uncertainty is not None):
310-
if self.image.uncertainty.uncertainty_type == 'var':
311-
variance = self.image.uncertainty.array
312-
elif self.image.uncertainty.uncertainty_type == 'std':
345+
if (hasattr(image, 'uncertainty')
346+
and image.uncertainty is not None):
347+
if image.uncertainty.uncertainty_type == 'var':
348+
variance = image.uncertainty.array
349+
elif image.uncertainty.uncertainty_type == 'std':
313350
warnings.warn("image NDData object's uncertainty "
314351
"interpreted as standard deviation. if "
315352
"incorrect, use VarianceUncertainty when "
316353
"assigning image object's uncertainty.")
317-
variance = self.image.uncertainty.array**2
318-
elif self.image.uncertainty.uncertainty_type == 'ivar':
319-
variance = 1 / self.image.uncertainty.array
354+
variance = image.uncertainty.array**2
355+
elif image.uncertainty.uncertainty_type == 'ivar':
356+
variance = 1 / image.uncertainty.array
320357
else:
321-
# other options are InverseVariance and UnknownVariance
358+
# other options are InverseUncertainty and UnknownUncertainty
322359
raise ValueError("image NDData object has unexpected "
323360
"uncertainty type. instead, try "
324361
"VarianceUncertainty or StdDevUncertainty.")
325-
elif (hasattr(self.image, 'uncertainty')
326-
and self.image.uncertainty is None):
362+
elif (hasattr(image, 'uncertainty')
363+
and image.uncertainty is None):
327364
# ignore variance arg to focus on updating NDData object
328365
raise ValueError('image NDData object lacks uncertainty')
329366
else:
@@ -332,7 +369,7 @@ def _parse_image(self, variance=None, mask=None, unit=None):
332369
"variance must be specified. consider "
333370
"wrapping it into one object by instead "
334371
"passing an NDData image.")
335-
elif self.image.shape != variance.shape:
372+
elif image.shape != variance.shape:
336373
raise ValueError("image and variance shapes must match")
337374

338375
if np.any(variance < 0):
@@ -349,16 +386,14 @@ def _parse_image(self, variance=None, mask=None, unit=None):
349386

350387
variance = VarianceUncertainty(variance)
351388

352-
unit = getattr(self.image, 'unit',
353-
u.Unit(self.unit) if self.unit is not None else u.Unit())
389+
unit = getattr(image, 'unit',
390+
u.Unit(unit) if unit is not None else u.Unit())
354391

355-
spectral_axis = getattr(self.image, 'spectral_axis',
356-
(np.arange(img.shape[self.disp_axis])
357-
if hasattr(self, 'disp_axis')
358-
else np.arange(img.shape[1])) * u.pix)
392+
spectral_axis = getattr(image, 'spectral_axis',
393+
np.arange(img.shape[disp_axis]) * u.pix)
359394

360-
self.image = Spectrum1D(img * unit, spectral_axis=spectral_axis,
361-
uncertainty=variance, mask=mask)
395+
return Spectrum1D(img * unit, spectral_axis=spectral_axis,
396+
uncertainty=variance, mask=mask)
362397

363398
def __call__(self, image=None, trace_object=None,
364399
disp_axis=None, crossdisp_axis=None,
@@ -423,7 +458,7 @@ def __call__(self, image=None, trace_object=None,
423458
unit = unit if unit is not None else self.unit
424459

425460
# parse image and replace optional arguments with updated values
426-
self._parse_image(variance, mask, unit)
461+
self.image = self._parse_image(image, variance, mask, unit, disp_axis)
427462
variance = self.image.uncertainty.array
428463
unit = self.image.unit
429464

specreduce/tests/test_extract.py

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
import astropy.units as u
5-
from astropy.nddata import CCDData
5+
from astropy.nddata import CCDData, VarianceUncertainty, UnknownUncertainty
66

77
from specreduce.extract import BoxcarExtract, HorneExtract, OptimalExtract
88
from specreduce.tracing import FlatTrace, ArrayTrace
@@ -79,7 +79,7 @@ def test_boxcar_array_trace():
7979
assert np.allclose(spectrum.flux.value, np.full_like(spectrum.flux.value, 75.))
8080

8181

82-
def test_horne_array_validation():
82+
def test_horne_image_validation():
8383
#
8484
# Test HorneExtract scenarios specific to its use with an image of
8585
# type `~numpy.ndarray` (instead of the default `~astropy.nddata.NDData`).
@@ -88,47 +88,65 @@ def test_horne_array_validation():
8888
extract = OptimalExtract(image.data, trace) # equivalent to HorneExtract
8989

9090
# an array-type image must come with a variance argument
91-
with pytest.raises(ValueError, match=r'.*array.*variance.*specified.*'):
91+
with pytest.raises(ValueError, match=r'.*variance must be specified.*'):
9292
ext = extract()
9393

94+
# an NDData-type image can't have an empty uncertainty attribute
95+
with pytest.raises(ValueError, match=r'.*NDData object lacks uncertainty'):
96+
ext = extract(image=image)
97+
98+
# an NDData-type image's uncertainty must be of type VarianceUncertainty
99+
# or type StdDevUncertainty
100+
with pytest.raises(ValueError, match=r'.*unexpected uncertainty type.*'):
101+
err = UnknownUncertainty(np.ones_like(image))
102+
image.uncertainty = err
103+
ext = extract(image=image)
104+
94105
# an array-type image must have the same dimensions as its variance argument
95106
with pytest.raises(ValueError, match=r'.*shapes must match.*'):
107+
# remember variance, mask, and unit args are only checked if image
108+
# object doesn't have those attributes (e.g., numpy and Quantity arrays)
96109
err = np.ones_like(image[0])
97-
ext = extract(variance=err)
110+
ext = extract(image=image.data, variance=err)
98111

99112
# an array-type image must have the same dimensions as its mask argument
100113
with pytest.raises(ValueError, match=r'.*shapes must match.*'):
101114
err = np.ones_like(image)
102115
mask = np.zeros_like(image[0])
103-
ext = extract(variance=err, mask=mask)
116+
ext = extract(image=image.data, variance=err, mask=mask)
104117

105118
# an array-type image given without mask and unit arguments is fine
106-
# and produces a unitless result
119+
# and produces an extraction with unitless flux and spectral axis in pixels
107120
err = np.ones_like(image)
108-
ext = extract(variance=err)
121+
ext = extract(image=image.data, variance=err, mask=None, unit=None)
109122
assert ext.unit == u.Unit()
123+
assert np.all(ext.spectral_axis
124+
== np.arange(image.shape[extract.disp_axis]) * u.pix)
110125

111126

112127
def test_horne_variance_errors():
113128
trace = FlatTrace(image, 3.0)
114129

115-
# all zeros are treated as non-weighted (give non-zero fluxes)
116-
err = np.zeros_like(image)
117-
mask = np.zeros_like(image)
118-
extract = HorneExtract(image.data, trace, variance=err, mask=mask, unit=u.Jy)
130+
# all zeros are treated as non-weighted (i.e., as non-zero fluxes)
131+
image.uncertainty = VarianceUncertainty(np.zeros_like(image))
132+
image.mask = np.zeros_like(image)
133+
extract = HorneExtract(image, trace)
119134
ext = extract.spectrum
120135
assert not np.all(ext == 0)
121136

122-
# single zero value adjusts mask (does not raise error)
137+
# single zero value adjusts mask and does not raise error
123138
err = np.ones_like(image)
124-
err[0] = 0
125-
mask = np.zeros_like(image)
126-
ext = extract(variance=err, mask=mask, unit=u.Jy)
127-
assert not np.all(ext == 0)
139+
err[0][0] = 0
140+
image.uncertainty = VarianceUncertainty(err)
141+
ext = extract(image)
142+
assert not np.all(ext == 1)
128143

129144
# single negative value raises error
130-
err = np.ones_like(image)
131-
err[0] = -1
145+
err = image.uncertainty.array
146+
err[0][0] = -1
132147
mask = np.zeros_like(image)
133148
with pytest.raises(ValueError, match='variance must be fully positive'):
134-
ext = extract(variance=err, mask=mask, unit=u.Jy)
149+
# remember variance, mask, and unit args are only checked if image
150+
# object doesn't have those attributes (e.g., numpy and Quantity arrays)
151+
ext = extract(image=image.data, variance=err,
152+
mask=image.mask, unit=u.Jy)

0 commit comments

Comments
 (0)