Skip to content

Commit 51d7582

Browse files
authored
Update allowed input types for stretch and cuts (astropy#44)
* Update allowed input types for stretch and cuts * Remove autocut_options and stretch_options * Change cuts/stretch from properties to set/get methods
1 parent 425ae4c commit 51d7582

File tree

3 files changed

+85
-48
lines changed

3 files changed

+85
-48
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from astropy.table import Table, vstack
99
from astropy.units import Quantity, get_physical_type
1010
from astropy.wcs import WCS
11+
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LinearStretch, ManualInterval
1112
from numpy.typing import ArrayLike
1213

1314
from .interface_definition import ImageViewerInterface
@@ -25,12 +26,10 @@ class ImageViewer:
2526
image_width: int = 0
2627
image_height: int = 0
2728
zoom_level: float = 1
28-
stretch_options: tuple = ("linear", "log", "sqrt")
29-
autocut_options: tuple = ("minmax", "zscale", "asinh", "percentile", "histogram")
3029
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
3130
marker: Any = "marker"
32-
_cuts: str | tuple[float, float] = (0, 1)
33-
_stretch: str = "linear"
31+
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(upper_percentile=95)
32+
_stretch: BaseStretch = LinearStretch
3433
# viewer: Any
3534

3635
# Allowed locations for cursor display
@@ -49,32 +48,24 @@ class ImageViewer:
4948
_wcs: WCS | None = None
5049
_center: tuple[float, float] = (0.0, 0.0)
5150

52-
@property
53-
def stretch(self) -> str:
51+
def get_stretch(self) -> BaseStretch:
5452
return self._stretch
5553

56-
@stretch.setter
57-
def stretch(self, value: str) -> None:
58-
if value not in self.stretch_options:
59-
raise ValueError(f"Stretch option {value} is not valid. Must be one of {self.stretch_options}.")
54+
def set_stretch(self, value: BaseStretch) -> None:
55+
if not isinstance(value, BaseStretch):
56+
raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.")
6057
self._stretch = value
6158

62-
@property
63-
def cuts(self) -> tuple:
59+
def get_cuts(self) -> tuple:
6460
return self._cuts
6561

66-
@cuts.setter
67-
def cuts(self, value: tuple) -> None:
68-
if isinstance(value, str):
69-
if value not in self.autocut_options:
70-
raise ValueError(f"Cut option {value} is not valid. Must be one of {self.autocut_options}.")
71-
# A real viewer would calculate the cuts based on the data
72-
self._cuts = (0, 1)
73-
return
74-
75-
if len(value) != 2:
76-
raise ValueError("Cuts must have length 2.")
77-
self._cuts = value
62+
def set_cuts(self, value: tuple[float, float] | BaseInterval) -> None:
63+
if isinstance(value, tuple) and len(value) == 2:
64+
self._cuts = ManualInterval(value[0], value[1])
65+
elif isinstance(value, BaseInterval):
66+
self._cuts = value
67+
else:
68+
raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
7869

7970
@property
8071
def cursor(self) -> str:

src/astro_image_display_api/interface_definition.py

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from astropy.nddata import NDData
77
from astropy.table import Table
88
from astropy.units import Quantity
9+
from astropy.visualization import BaseInterval, BaseStretch
910

1011
from numpy.typing import ArrayLike
1112

@@ -30,13 +31,8 @@ class ImageViewerInterface(Protocol):
3031
image_width: int
3132
image_height: int
3233
zoom_level: float
33-
stretch_options: tuple
34-
autocut_options: tuple
3534
cursor: str
3635
marker: Any
37-
cuts: Any
38-
stretch: str
39-
# viewer: Any
4036

4137
# Allowed locations for cursor display
4238
ALLOWED_CURSOR_LOCATIONS: tuple = ALLOWED_CURSOR_LOCATIONS
@@ -65,6 +61,58 @@ def load_image(self, data: Any) -> None:
6561
"""
6662
raise NotImplementedError
6763

64+
# Setting and getting image properties
65+
@abstractmethod
66+
def set_cuts(self, cuts: tuple | BaseInterval) -> None:
67+
"""
68+
Set the cuts for the image.
69+
70+
Parameters
71+
----------
72+
cuts : tuple or any Interval from `astropy.visualization`
73+
The cuts to set. If a tuple, it should be of the form
74+
``(min, max)`` and will be interpreted as a
75+
`~astropy.visualization.ManualInterval`.
76+
"""
77+
raise NotImplementedError
78+
79+
@abstractmethod
80+
def get_cuts(self) -> BaseInterval:
81+
"""
82+
Get the current cuts for the image.
83+
84+
Returns
85+
-------
86+
cuts : `~astropy.visualization.BaseInterval`
87+
The Astropy interval object representing the current cuts.
88+
"""
89+
raise NotImplementedError
90+
91+
@abstractmethod
92+
def set_stretch(self, stretch: BaseStretch) -> None:
93+
"""
94+
Set the stretch for the image.
95+
96+
Parameters
97+
----------
98+
stretch : Any stretch from `~astropy.visualization`
99+
The stretch to set. This can be any subclass of
100+
`~astropy.visualization.BaseStretch`.
101+
"""
102+
raise NotImplementedError
103+
104+
@abstractmethod
105+
def get_stretch(self) -> BaseStretch:
106+
"""
107+
Get the current stretch for the image.
108+
109+
Returns
110+
-------
111+
stretch : `~astropy.visualization.BaseStretch`
112+
The Astropy stretch object representing the current stretch.
113+
"""
114+
raise NotImplementedError
115+
68116
# Saving contents of the view and accessing the view
69117
@abstractmethod
70118
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:

src/astro_image_display_api/widget_api_test.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from astropy.table import Table, vstack # noqa: E402
1010
from astropy import units as u # noqa: E402
1111
from astropy.wcs import WCS # noqa: E402
12+
from astropy.visualization import AsymmetricPercentileInterval, LogStretch, ManualInterval
1213

1314
__all__ = ['ImageWidgetAPITest']
1415

@@ -268,37 +269,34 @@ def test_adding_markers_as_world(self, data, wcs):
268269
mark_coord_table['coord'].dec.deg)
269270

270271
def test_stretch(self):
271-
# Check that the stretch options is not an empty list
272-
assert len(self.image.stretch_options) > 0
272+
original_stretch = self.image.get_stretch()
273273

274-
original_stretch = self.image.stretch
275-
276-
with pytest.raises(ValueError, match='[mM]ust be one of'):
277-
self.image.stretch = 'not a valid value'
274+
with pytest.raises(ValueError, match=r'Stretch.*not valid.*'):
275+
self.image.set_stretch('not a valid value')
278276

279277
# A bad value should leave the stretch unchanged
280-
assert self.image.stretch is original_stretch
278+
assert self.image.get_stretch() is original_stretch
281279

282-
self.image.stretch = 'log'
280+
self.image.set_stretch(LogStretch())
283281
# A valid value should change the stretch
284-
assert self.image.stretch is not original_stretch
282+
assert self.image.get_stretch() is not original_stretch
283+
assert isinstance(self.image.get_stretch(), LogStretch)
285284

286285
def test_cuts(self, data):
287-
with pytest.raises(ValueError, match='[mM]ust be one of'):
288-
self.image.cuts = 'not a valid value'
289-
290-
with pytest.raises(ValueError, match='must have length 2'):
291-
self.image.cuts = (1, 10, 100)
286+
with pytest.raises(ValueError, match='[mM]ust be'):
287+
self.image.set_cuts('not a valid value')
292288

293-
assert 'histogram' in self.image.autocut_options
289+
with pytest.raises(ValueError, match='[mM]ust be'):
290+
self.image.set_cuts((1, 10, 100))
294291

295292
# Setting using histogram requires data
296293
self.image.load_image(data)
297-
self.image.cuts = 'histogram'
298-
assert len(self.image.cuts) == 2
294+
self.image.set_cuts(AsymmetricPercentileInterval(0.1, 99.9))
295+
assert isinstance(self.image.get_cuts(), AsymmetricPercentileInterval)
299296

300-
self.image.cuts = (10, 100)
301-
assert self.image.cuts == (10, 100)
297+
self.image.set_cuts((10, 100))
298+
assert isinstance(self.image.get_cuts(), ManualInterval)
299+
assert self.image.get_cuts().get_limits(data) == (10, 100)
302300

303301
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
304302
def test_colormap(self):

0 commit comments

Comments
 (0)