Skip to content

Commit fd33ab7

Browse files
committed
Change cuts/stretch from properties to set/get methods
1 parent e20be84 commit fd33ab7

File tree

3 files changed

+72
-28
lines changed

3 files changed

+72
-28
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from astropy.table import Table, vstack
99
from astropy.units import Quantity, get_physical_type
1010
from astropy.wcs import WCS
11-
from astropy.visualization import BaseInterval, BaseStretch, ManualInterval
11+
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LinearStretch, ManualInterval
1212
from numpy.typing import ArrayLike
1313

1414
from .interface_definition import ImageViewerInterface
@@ -28,8 +28,8 @@ class ImageViewer:
2828
zoom_level: float = 1
2929
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
3030
marker: Any = "marker"
31-
_cuts: str | tuple[float, float] = (0, 1)
32-
_stretch: str = "linear"
31+
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(upper_percentile=95)
32+
_stretch: BaseStretch = LinearStretch
3333
# viewer: Any
3434

3535
# Allowed locations for cursor display
@@ -48,30 +48,25 @@ class ImageViewer:
4848
_wcs: WCS | None = None
4949
_center: tuple[float, float] = (0.0, 0.0)
5050

51-
@property
52-
def stretch(self) -> BaseStretch:
51+
def get_stretch(self) -> BaseStretch:
5352
return self._stretch
5453

55-
@stretch.setter
56-
def stretch(self, value: BaseStretch) -> None:
54+
def set_stretch(self, value: BaseStretch) -> None:
5755
if not isinstance(value, BaseStretch):
5856
raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.")
5957
self._stretch = value
6058

61-
@property
62-
def cuts(self) -> tuple:
59+
def get_cuts(self) -> tuple:
6360
return self._cuts
6461

65-
@cuts.setter
66-
def cuts(self, value: tuple) -> None:
62+
def set_cuts(self, value: tuple[float, float] | BaseInterval) -> None:
6763
if isinstance(value, tuple) and len(value) == 2:
6864
self._cuts = ManualInterval(value[0], value[1])
6965
elif isinstance(value, BaseInterval):
7066
self._cuts = value
7167
else:
7268
raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
7369

74-
7570
@property
7671
def cursor(self) -> str:
7772
return self._cursor

src/astro_image_display_api/interface_definition.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ class ImageViewerInterface(Protocol):
3333
zoom_level: float
3434
cursor: str
3535
marker: Any
36-
cuts: tuple | BaseInterval
37-
stretch: BaseStretch
38-
# viewer: Any
3936

4037
# Allowed locations for cursor display
4138
ALLOWED_CURSOR_LOCATIONS: tuple = ALLOWED_CURSOR_LOCATIONS
@@ -64,6 +61,58 @@ def load_image(self, data: Any) -> None:
6461
"""
6562
raise NotImplementedError
6663

64+
# Setting and getting image properties
65+
@abstractmethod
66+
def set_cuts(self, cuts: tuple | BaseInterval) -> None:
67+
"""
68+
Set the cuts for the image.
69+
70+
Parameters
71+
----------
72+
cuts : tuple or any Interval from `astropy.visualization`
73+
The cuts to set. If a tuple, it should be of the form
74+
``(min, max)`` and will be interpreted as a
75+
`~astropy.visualization.ManualInterval`.
76+
"""
77+
raise NotImplementedError
78+
79+
@abstractmethod
80+
def get_cuts(self) -> BaseInterval:
81+
"""
82+
Get the current cuts for the image.
83+
84+
Returns
85+
-------
86+
cuts : `~astropy.visualization.BaseInterval`
87+
The Astropy interval object representing the current cuts.
88+
"""
89+
raise NotImplementedError
90+
91+
@abstractmethod
92+
def set_stretch(self, stretch: BaseStretch) -> None:
93+
"""
94+
Set the stretch for the image.
95+
96+
Parameters
97+
----------
98+
stretch : Any stretch from `~astropy.visualization`
99+
The stretch to set. This can be any subclass of
100+
`~astropy.visualization.BaseStretch`.
101+
"""
102+
raise NotImplementedError
103+
104+
@abstractmethod
105+
def get_stretch(self) -> BaseStretch:
106+
"""
107+
Get the current stretch for the image.
108+
109+
Returns
110+
-------
111+
stretch : `~astropy.visualization.BaseStretch`
112+
The Astropy stretch object representing the current stretch.
113+
"""
114+
raise NotImplementedError
115+
67116
# Saving contents of the view and accessing the view
68117
@abstractmethod
69118
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:

src/astro_image_display_api/widget_api_test.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -269,34 +269,34 @@ def test_adding_markers_as_world(self, data, wcs):
269269
mark_coord_table['coord'].dec.deg)
270270

271271
def test_stretch(self):
272-
original_stretch = self.image.stretch
272+
original_stretch = self.image.get_stretch()
273273

274274
with pytest.raises(ValueError, match=r'Stretch.*not valid.*'):
275-
self.image.stretch = 'not a valid value'
275+
self.image.set_stretch('not a valid value')
276276

277277
# A bad value should leave the stretch unchanged
278-
assert self.image.stretch is original_stretch
278+
assert self.image.get_stretch() is original_stretch
279279

280-
self.image.stretch = LogStretch()
280+
self.image.set_stretch(LogStretch())
281281
# A valid value should change the stretch
282-
assert self.image.stretch is not original_stretch
283-
assert isinstance(self.image.stretch, LogStretch)
282+
assert self.image.get_stretch() is not original_stretch
283+
assert isinstance(self.image.get_stretch(), LogStretch)
284284

285285
def test_cuts(self, data):
286286
with pytest.raises(ValueError, match='[mM]ust be'):
287-
self.image.cuts = 'not a valid value'
287+
self.image.set_cuts('not a valid value')
288288

289289
with pytest.raises(ValueError, match='[mM]ust be'):
290-
self.image.cuts = (1, 10, 100)
290+
self.image.set_cuts((1, 10, 100))
291291

292292
# Setting using histogram requires data
293293
self.image.load_image(data)
294-
self.image.cuts = AsymmetricPercentileInterval(0.1, 99.9)
295-
assert isinstance(self.image.cuts, AsymmetricPercentileInterval)
294+
self.image.set_cuts(AsymmetricPercentileInterval(0.1, 99.9))
295+
assert isinstance(self.image.get_cuts(), AsymmetricPercentileInterval)
296296

297-
self.image.cuts = (10, 100)
298-
assert isinstance(self.image.cuts, ManualInterval)
299-
assert self.image.cuts.get_limits(data) == (10, 100)
297+
self.image.set_cuts((10, 100))
298+
assert isinstance(self.image.get_cuts(), ManualInterval)
299+
assert self.image.get_cuts().get_limits(data) == (10, 100)
300300

301301
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
302302
def test_colormap(self):

0 commit comments

Comments
 (0)