Skip to content

Commit 692e74b

Browse files
committed
Update allowed input types for stretch and cuts
1 parent 425ae4c commit 692e74b

File tree

3 files changed

+24
-23
lines changed

3 files changed

+24
-23
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from astropy.table import Table, vstack
99
from astropy.units import Quantity, get_physical_type
1010
from astropy.wcs import WCS
11+
from astropy.visualization import BaseInterval, BaseStretch, ManualInterval
1112
from numpy.typing import ArrayLike
1213

1314
from .interface_definition import ImageViewerInterface
@@ -50,12 +51,12 @@ class ImageViewer:
5051
_center: tuple[float, float] = (0.0, 0.0)
5152

5253
@property
53-
def stretch(self) -> str:
54+
def stretch(self) -> BaseStretch:
5455
return self._stretch
5556

5657
@stretch.setter
57-
def stretch(self, value: str) -> None:
58-
if value not in self.stretch_options:
58+
def stretch(self, value: BaseStretch) -> None:
59+
if not isinstance(value, BaseStretch):
5960
raise ValueError(f"Stretch option {value} is not valid. Must be one of {self.stretch_options}.")
6061
self._stretch = value
6162

@@ -65,16 +66,13 @@ def cuts(self) -> tuple:
6566

6667
@cuts.setter
6768
def cuts(self, value: tuple) -> None:
68-
if isinstance(value, str):
69-
if value not in self.autocut_options:
70-
raise ValueError(f"Cut option {value} is not valid. Must be one of {self.autocut_options}.")
71-
# A real viewer would calculate the cuts based on the data
72-
self._cuts = (0, 1)
73-
return
74-
75-
if len(value) != 2:
76-
raise ValueError("Cuts must have length 2.")
77-
self._cuts = value
69+
if isinstance(value, tuple) and len(value) == 2:
70+
self._cuts = ManualInterval(value[0], value[1])
71+
elif isinstance(value, BaseInterval):
72+
self._cuts = value
73+
else:
74+
raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
75+
7876

7977
@property
8078
def cursor(self) -> str:

src/astro_image_display_api/interface_definition.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from astropy.nddata import NDData
77
from astropy.table import Table
88
from astropy.units import Quantity
9+
from astropy.visualization import BaseInterval, BaseStretch
910

1011
from numpy.typing import ArrayLike
1112

@@ -34,8 +35,8 @@ class ImageViewerInterface(Protocol):
3435
autocut_options: tuple
3536
cursor: str
3637
marker: Any
37-
cuts: Any
38-
stretch: str
38+
cuts: tuple | BaseInterval
39+
stretch: BaseStretch
3940
# viewer: Any
4041

4142
# Allowed locations for cursor display

src/astro_image_display_api/widget_api_test.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from astropy.table import Table, vstack # noqa: E402
1010
from astropy import units as u # noqa: E402
1111
from astropy.wcs import WCS # noqa: E402
12+
from astropy.visualization import AsymmetricPercentileInterval, LogStretch, ManualInterval
1213

1314
__all__ = ['ImageWidgetAPITest']
1415

@@ -279,26 +280,27 @@ def test_stretch(self):
279280
# A bad value should leave the stretch unchanged
280281
assert self.image.stretch is original_stretch
281282

282-
self.image.stretch = 'log'
283+
self.image.stretch = LogStretch()
283284
# A valid value should change the stretch
284285
assert self.image.stretch is not original_stretch
286+
assert isinstance(self.image.stretch, LogStretch)
285287

286288
def test_cuts(self, data):
287-
with pytest.raises(ValueError, match='[mM]ust be one of'):
289+
assert len(self.image.autocut_options) > 0
290+
with pytest.raises(ValueError, match='[mM]ust be'):
288291
self.image.cuts = 'not a valid value'
289292

290-
with pytest.raises(ValueError, match='must have length 2'):
293+
with pytest.raises(ValueError, match='[mM]ust be'):
291294
self.image.cuts = (1, 10, 100)
292295

293-
assert 'histogram' in self.image.autocut_options
294-
295296
# Setting using histogram requires data
296297
self.image.load_image(data)
297-
self.image.cuts = 'histogram'
298-
assert len(self.image.cuts) == 2
298+
self.image.cuts = AsymmetricPercentileInterval(0.1, 99.9)
299+
assert isinstance(self.image.cuts, AsymmetricPercentileInterval)
299300

300301
self.image.cuts = (10, 100)
301-
assert self.image.cuts == (10, 100)
302+
assert isinstance(self.image.cuts, ManualInterval)
303+
assert self.image.cuts.get_limits(data) == (10, 100)
302304

303305
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
304306
def test_colormap(self):

0 commit comments

Comments
 (0)