Skip to content

Commit 21660b9

Browse files
committed
Add colormap API to interface and tests
1 parent 04ebfb1 commit 21660b9

File tree

2 files changed

+123
-6
lines changed

2 files changed

+123
-6
lines changed

src/astro_image_display_api/interface_definition.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,23 @@
1212

1313
# Allowed locations for cursor display
1414
ALLOWED_CURSOR_LOCATIONS = ('top', 'bottom', None)
15-
15+
MINIMUM_REQUIRED_COLORMAPS = (
16+
'gray',
17+
'viridis',
18+
'plasma',
19+
'inferno',
20+
'magma',
21+
'purple-blue',
22+
'yellow-green-blue',
23+
'yellow-orange-red',
24+
'red-purple',
25+
'blue-green',
26+
'hot',
27+
'red-blue',
28+
'red-yellow-blue',
29+
'purple-orange'
30+
'purple-green',
31+
)
1632

1733
__all__ = [
1834
'ImageViewerInterface',
@@ -31,6 +47,9 @@ class ImageViewerInterface(Protocol):
3147
# Allowed locations for cursor display
3248
ALLOWED_CURSOR_LOCATIONS: tuple = ALLOWED_CURSOR_LOCATIONS
3349

50+
# Required colormaps for the viewer
51+
MINIMUM_REQUIRED_COLORMAPS: tuple[str, ...] = MINIMUM_REQUIRED_COLORMAPS
52+
3453
# The methods, grouped loosely by purpose
3554

3655
# Method for loading image data
@@ -130,6 +149,69 @@ def get_stretch(self, image_label: str | None = None) -> BaseStretch:
130149
"""
131150
raise NotImplementedError
132151

152+
@abstractmethod
153+
def set_colormap(self, map_name: str, image_label: str | None = None) -> None:
154+
"""
155+
Set the colormap for the image specified by image_label.
156+
157+
Parameters
158+
----------
159+
map_name : str
160+
The name of the colormap to set. This should be a valid
161+
colormap name from Matplotlib; not all backends will support
162+
all colormaps, so the viewer should handle errors gracefully.
163+
The case of the `map_name` is not important.
164+
image_label : str, optional
165+
The label of the image to set the colormap for. If not given and there is
166+
only one image loaded, the colormap for that image is set. If there are
167+
multiple images and no label is provided, an error is raised.
168+
169+
Raises
170+
------
171+
ValueError
172+
If the `map_name` is not a valid colormap name or if the `image_label`
173+
is not provided when there are multiple images loaded.
174+
"""
175+
raise NotImplementedError
176+
177+
@abstractmethod
178+
def get_colormap(self, image_label: str | None = None) -> str:
179+
"""
180+
Get the current colormap for the image.
181+
182+
Parameters
183+
----------
184+
image_label : str, optional
185+
The label of the image to get the colormap for. If not given and there is
186+
only one image loaded, the colormap for that image is returned. If there are
187+
multiple images and no label is provided, an error is raised.
188+
189+
Returns
190+
-------
191+
map_name : str
192+
The name of the current colormap.
193+
194+
Raises
195+
------
196+
ValueError
197+
If the `image_label` is not provided when there are multiple images loaded or if
198+
the `image_label` does not correspond to a loaded image.
199+
"""
200+
raise NotImplementedError
201+
202+
@property
203+
@abstractmethod
204+
def colormap_options(self) -> list[str]:
205+
"""
206+
Get the list of available colormaps.
207+
208+
Returns
209+
-------
210+
list of str
211+
A list of available colormap names.
212+
"""
213+
raise NotImplementedError
214+
133215
# Saving contents of the view and accessing the view
134216
@abstractmethod
135217
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:

src/astro_image_display_api/widget_api_test.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
from astropy.coordinates import SkyCoord
88
from astropy.io import fits
99
from astropy.nddata import CCDData, NDData
10-
from astropy.table import Table, vstack
10+
from astropy.table import Table
1111
from astropy import units as u
1212
from astropy.wcs import WCS
1313
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LogStretch, ManualInterval
1414

15+
from .interface_definition import ImageViewerInterface
16+
1517
__all__ = ['ImageWidgetAPITest']
1618

1719

@@ -716,15 +718,48 @@ def test_stretch_cuts_errors(self, data):
716718
with pytest.raises(ValueError, match='[Ii]mage label.*not found'):
717719
self.image.set_cuts((10, 100), image_label='not a valid label')
718720

719-
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
720-
def test_colormap(self):
721-
cmap_desired = 'gray'
721+
def test_colormap_options(self):
722722
cmap_list = self.image.colormap_options
723-
assert len(cmap_list) > 0 and cmap_desired in cmap_list
723+
assert set(ImageViewerInterface.MINIMUM_REQUIRED_COLORMAPS) <= set(cmap_list)
724+
assert set(self.image.MINIMUM_REQUIRED_COLORMAPS) == set(ImageViewerInterface.MINIMUM_REQUIRED_COLORMAPS)
725+
726+
def test_set_get_colormap(self, data):
727+
# Check setting and getting with a single image label.
728+
self.image.load_image(data, image_label='test')
729+
cmap_desired = 'gray'
724730
self.image.set_colormap(cmap_desired)
731+
assert self.image.get_colormap() == cmap_desired
732+
733+
# Check that the colormap can be set with an image label
734+
new_cmap = "viridis"
735+
self.image.set_colormap(new_cmap, image_label='test')
736+
assert self.image.get_colormap(image_label='test') == new_cmap
737+
738+
def test_set_colormap_errors(self, data):
739+
# Check that setting a colormap raises an error if the colormap
740+
# is not in the list of allowed colormaps.
741+
self.image.load_image(data, image_label='test')
742+
743+
with pytest.raises(ValueError, match='[Ii]nvalid colormap'):
744+
self.image.set_colormap('not a valid colormap')
745+
746+
# Check that getting a colormap for an image label that does not exist
747+
with pytest.raises(ValueError, match='[Ii]mage label.*not found'):
748+
self.image.get_colormap(image_label='not a valid label')
749+
750+
# Check that setting a colormap without an image label fails
751+
# when there is more than one image label
752+
self.image.load_image(data, image_label='another test')
753+
with pytest.raises(ValueError, match='Multiple image labels defined'):
754+
self.image.set_colormap('gray')
755+
756+
# Same for getting the colormap without an image label
757+
with pytest.raises(ValueError, match='Multiple image labels defined'):
758+
self.image.get_colormap()
725759

726760
def test_cursor(self):
727761
assert self.image.cursor in self.image.ALLOWED_CURSOR_LOCATIONS
762+
assert set(ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS) == set(self.image.ALLOWED_CURSOR_LOCATIONS)
728763
with pytest.raises(self.cursor_error_classes):
729764
self.image.cursor = 'not a valid option'
730765
self.image.cursor = 'bottom'

0 commit comments

Comments
 (0)