Skip to content

Commit 624b8d3

Browse files
committed
Change cuts/stretch from properties to set/get methods
1 parent 6f6dab2 commit 624b8d3

File tree

3 files changed

+72
-28
lines changed

3 files changed

+72
-28
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from astropy.table import Table, vstack
99
from astropy.units import Quantity, get_physical_type
1010
from astropy.wcs import WCS
11-
from astropy.visualization import BaseInterval, BaseStretch, ManualInterval
11+
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LinearStretch, ManualInterval
1212
from numpy.typing import ArrayLike
1313

1414
from .interface_definition import ImageViewerInterface
@@ -31,8 +31,8 @@ class ImageViewer:
3131
zoom_level: float = 1
3232
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
3333
marker: Any = "marker"
34-
_cuts: str | tuple[float, float] = (0, 1)
35-
_stretch: str = "linear"
34+
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(upper_percentile=95)
35+
_stretch: BaseStretch = LinearStretch
3636
# viewer: Any
3737

3838
# Allowed locations for cursor display
@@ -72,30 +72,25 @@ def click_drag(self, value: bool) -> None:
7272
self._click_drag = value
7373
self._click_center = not value
7474

75-
@property
76-
def stretch(self) -> BaseStretch:
75+
def get_stretch(self) -> BaseStretch:
7776
return self._stretch
7877

79-
@stretch.setter
80-
def stretch(self, value: BaseStretch) -> None:
78+
def set_stretch(self, value: BaseStretch) -> None:
8179
if not isinstance(value, BaseStretch):
8280
raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.")
8381
self._stretch = value
8482

85-
@property
86-
def cuts(self) -> tuple:
83+
def get_cuts(self) -> tuple:
8784
return self._cuts
8885

89-
@cuts.setter
90-
def cuts(self, value: tuple) -> None:
86+
def set_cuts(self, value: tuple[float, float] | BaseInterval) -> None:
9187
if isinstance(value, tuple) and len(value) == 2:
9288
self._cuts = ManualInterval(value[0], value[1])
9389
elif isinstance(value, BaseInterval):
9490
self._cuts = value
9591
else:
9692
raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
9793

98-
9994
@property
10095
def cursor(self) -> str:
10196
return self._cursor

src/astro_image_display_api/interface_definition.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,6 @@ class ImageViewerInterface(Protocol):
3636
zoom_level: float
3737
cursor: str
3838
marker: Any
39-
cuts: tuple | BaseInterval
40-
stretch: BaseStretch
41-
# viewer: Any
4239

4340
# Allowed locations for cursor display
4441
ALLOWED_CURSOR_LOCATIONS: tuple = ALLOWED_CURSOR_LOCATIONS
@@ -89,6 +86,58 @@ def load_nddata(self, data: NDData) -> None:
8986
"""
9087
raise NotImplementedError
9188

89+
# Setting and getting image properties
90+
@abstractmethod
91+
def set_cuts(self, cuts: tuple | BaseInterval) -> None:
92+
"""
93+
Set the cuts for the image.
94+
95+
Parameters
96+
----------
97+
cuts : tuple or any Interval from `astropy.visualization`
98+
The cuts to set. If a tuple, it should be of the form
99+
``(min, max)`` and will be interpreted as a
100+
`~astropy.visualization.ManualInterval`.
101+
"""
102+
raise NotImplementedError
103+
104+
@abstractmethod
105+
def get_cuts(self) -> BaseInterval:
106+
"""
107+
Get the current cuts for the image.
108+
109+
Returns
110+
-------
111+
cuts : `~astropy.visualization.BaseInterval`
112+
The Astropy interval object representing the current cuts.
113+
"""
114+
raise NotImplementedError
115+
116+
@abstractmethod
117+
def set_stretch(self, stretch: BaseStretch) -> None:
118+
"""
119+
Set the stretch for the image.
120+
121+
Parameters
122+
----------
123+
stretch : Any stretch from `~astropy.visualization`
124+
The stretch to set. This can be any subclass of
125+
`~astropy.visualization.BaseStretch`.
126+
"""
127+
raise NotImplementedError
128+
129+
@abstractmethod
130+
def get_stretch(self) -> BaseStretch:
131+
"""
132+
Get the current stretch for the image.
133+
134+
Returns
135+
-------
136+
stretch : `~astropy.visualization.BaseStretch`
137+
The Astropy stretch object representing the current stretch.
138+
"""
139+
raise NotImplementedError
140+
92141
# Saving contents of the view and accessing the view
93142
@abstractmethod
94143
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:

src/astro_image_display_api/widget_api_test.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -270,34 +270,34 @@ def test_adding_markers_as_world(self, data, wcs):
270270
mark_coord_table['coord'].dec.deg)
271271

272272
def test_stretch(self):
273-
original_stretch = self.image.stretch
273+
original_stretch = self.image.get_stretch()
274274

275275
with pytest.raises(ValueError, match=r'Stretch.*not valid.*'):
276-
self.image.stretch = 'not a valid value'
276+
self.image.set_stretch('not a valid value')
277277

278278
# A bad value should leave the stretch unchanged
279-
assert self.image.stretch is original_stretch
279+
assert self.image.get_stretch() is original_stretch
280280

281-
self.image.stretch = LogStretch()
281+
self.image.set_stretch(LogStretch())
282282
# A valid value should change the stretch
283-
assert self.image.stretch is not original_stretch
284-
assert isinstance(self.image.stretch, LogStretch)
283+
assert self.image.get_stretch() is not original_stretch
284+
assert isinstance(self.image.get_stretch(), LogStretch)
285285

286286
def test_cuts(self, data):
287287
with pytest.raises(ValueError, match='[mM]ust be'):
288-
self.image.cuts = 'not a valid value'
288+
self.image.set_cuts('not a valid value')
289289

290290
with pytest.raises(ValueError, match='[mM]ust be'):
291-
self.image.cuts = (1, 10, 100)
291+
self.image.set_cuts((1, 10, 100))
292292

293293
# Setting using histogram requires data
294294
self.image.load_array(data)
295-
self.image.cuts = AsymmetricPercentileInterval(0.1, 99.9)
296-
assert isinstance(self.image.cuts, AsymmetricPercentileInterval)
295+
self.image.set_cuts(AsymmetricPercentileInterval(0.1, 99.9))
296+
assert isinstance(self.image.get_cuts(), AsymmetricPercentileInterval)
297297

298-
self.image.cuts = (10, 100)
299-
assert isinstance(self.image.cuts, ManualInterval)
300-
assert self.image.cuts.get_limits(data) == (10, 100)
298+
self.image.set_cuts((10, 100))
299+
assert isinstance(self.image.get_cuts(), ManualInterval)
300+
assert self.image.get_cuts().get_limits(data) == (10, 100)
301301

302302
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
303303
def test_colormap(self):

0 commit comments

Comments
 (0)