Skip to content

Commit a64f20b

Browse files
committed
Update allowed input types for stretch and cuts
1 parent b9891ac commit a64f20b

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from astropy.table import Table, vstack
99
from astropy.units import Quantity, get_physical_type
1010
from astropy.wcs import WCS
11+
from astropy.visualization import BaseInterval, BaseStretch, ManualInterval
1112
from numpy.typing import ArrayLike
1213

1314
from .interface_definition import ImageViewerInterface
@@ -74,12 +75,12 @@ def click_drag(self, value: bool) -> None:
7475
self._click_center = not value
7576

7677
@property
77-
def stretch(self) -> str:
78+
def stretch(self) -> BaseStretch:
7879
return self._stretch
7980

8081
@stretch.setter
81-
def stretch(self, value: str) -> None:
82-
if value not in self.stretch_options:
82+
def stretch(self, value: BaseStretch) -> None:
83+
if not isinstance(value, BaseStretch):
8384
raise ValueError(f"Stretch option {value} is not valid. Must be one of {self.stretch_options}.")
8485
self._stretch = value
8586

@@ -89,16 +90,13 @@ def cuts(self) -> tuple:
8990

9091
@cuts.setter
9192
def cuts(self, value: tuple) -> None:
92-
if isinstance(value, str):
93-
if value not in self.autocut_options:
94-
raise ValueError(f"Cut option {value} is not valid. Must be one of {self.autocut_options}.")
95-
# A real viewer would calculate the cuts based on the data
96-
self._cuts = (0, 1)
97-
return
98-
99-
if len(value) != 2:
100-
raise ValueError("Cuts must have length 2.")
101-
self._cuts = value
93+
if isinstance(value, tuple) and len(value) == 2:
94+
self._cuts = ManualInterval(value[0], value[1])
95+
elif isinstance(value, BaseInterval):
96+
self._cuts = value
97+
else:
98+
raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
99+
102100

103101
@property
104102
def cursor(self) -> str:

src/astro_image_display_api/interface_definition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from astropy.nddata import NDData
77
from astropy.table import Table
88
from astropy.units import Quantity
9+
from astropy.visualization import BaseInterval, BaseStretch
910

1011
from numpy.typing import ArrayLike
1112

@@ -37,8 +38,8 @@ class ImageViewerInterface(Protocol):
3738
autocut_options: tuple
3839
cursor: str
3940
marker: Any
40-
cuts: Any
41-
stretch: str
41+
cuts: tuple | BaseInterval
42+
stretch: BaseStretch
4243
# viewer: Any
4344

4445
# Allowed locations for cursor display

src/astro_image_display_api/widget_api_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from astropy.table import Table, vstack # noqa: E402
1111
from astropy import units as u # noqa: E402
1212
from astropy.wcs import WCS # noqa: E402
13+
from astropy.visualization import AsymmetricPercentileInterval, LogStretch, ManualInterval
1314

1415
__all__ = ['ImageWidgetAPITest']
1516

@@ -280,26 +281,27 @@ def test_stretch(self):
280281
# A bad value should leave the stretch unchanged
281282
assert self.image.stretch is original_stretch
282283

283-
self.image.stretch = 'log'
284+
self.image.stretch = LogStretch()
284285
# A valid value should change the stretch
285286
assert self.image.stretch is not original_stretch
287+
assert isinstance(self.image.stretch, LogStretch)
286288

287289
def test_cuts(self, data):
288-
with pytest.raises(ValueError, match='[mM]ust be one of'):
290+
assert len(self.image.autocut_options) > 0
291+
with pytest.raises(ValueError, match='[mM]ust be'):
289292
self.image.cuts = 'not a valid value'
290293

291-
with pytest.raises(ValueError, match='must have length 2'):
294+
with pytest.raises(ValueError, match='[mM]ust be'):
292295
self.image.cuts = (1, 10, 100)
293296

294-
assert 'histogram' in self.image.autocut_options
295-
296297
# Setting using histogram requires data
297298
self.image.load_array(data)
298-
self.image.cuts = 'histogram'
299-
assert len(self.image.cuts) == 2
299+
self.image.cuts = AsymmetricPercentileInterval(0.1, 99.9)
300+
assert isinstance(self.image.cuts, AsymmetricPercentileInterval)
300301

301302
self.image.cuts = (10, 100)
302-
assert self.image.cuts == (10, 100)
303+
assert isinstance(self.image.cuts, ManualInterval)
304+
assert self.image.cuts.get_limits(data) == (10, 100)
303305

304306
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
305307
def test_colormap(self):

0 commit comments

Comments
 (0)