Skip to content

Commit e719090

Browse files
authored
Merge pull request astropy#49 from mwcraig/add-colormap-appi
Add colormap api
2 parents 04ebfb1 + 4e885ba commit e719090

File tree

3 files changed

+102
-9
lines changed

3 files changed

+102
-9
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from astropy.coordinates import SkyCoord
1111
from astropy.nddata import CCDData, NDData
1212
from astropy.table import Table, vstack
13-
from astropy.units import Quantity, get_physical_type
13+
from astropy.units import Quantity
1414
from astropy.wcs import WCS
1515
from astropy.wcs.utils import proj_plane_pixel_scales
1616
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LinearStretch, ManualInterval
@@ -37,6 +37,7 @@ class ViewportInfo:
3737
largest_dimension: int | None = None
3838
stretch: BaseStretch | None = None
3939
cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None
40+
colormap: str | None = None
4041

4142
@dataclass
4243
class ImageViewer:
@@ -150,6 +151,22 @@ def set_cuts(self, value: tuple[numbers.Real, numbers.Real] | BaseInterval, imag
150151
raise ValueError(f"Image label '{image_label}' not found. Please load an image first.")
151152
self._images[image_label].cuts = self._cuts
152153

154+
def set_colormap(self, map_name: str, image_label: str | None = None) -> None:
155+
image_label = self._resolve_image_label(image_label)
156+
if image_label not in self._images:
157+
raise ValueError(f"Image label '{image_label}' not found. Please load an image first.")
158+
self._images[image_label].colormap = map_name
159+
160+
set_colormap.__doc__ = ImageViewerInterface.set_colormap.__doc__
161+
162+
def get_colormap(self, image_label: str | None = None) -> str:
163+
image_label = self._resolve_image_label(image_label)
164+
if image_label not in self._images:
165+
raise ValueError(f"Image label '{image_label}' not found. Please load an image first.")
166+
return self._images[image_label].colormap
167+
168+
get_colormap.__doc__ = ImageViewerInterface.get_colormap.__doc__
169+
153170
@property
154171
def cursor(self) -> str:
155172
return self._cursor

src/astro_image_display_api/interface_definition.py

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,14 @@
33
import os
44

55
from astropy.coordinates import SkyCoord
6-
from astropy.nddata import NDData
76
from astropy.table import Table
87
from astropy.units import Quantity
98
from astropy.visualization import BaseInterval, BaseStretch
109

11-
from numpy.typing import ArrayLike
1210

1311
# Allowed locations for cursor display
1412
ALLOWED_CURSOR_LOCATIONS = ('top', 'bottom', None)
1513

16-
1714
__all__ = [
1815
'ImageViewerInterface',
1916
]
@@ -130,6 +127,58 @@ def get_stretch(self, image_label: str | None = None) -> BaseStretch:
130127
"""
131128
raise NotImplementedError
132129

130+
@abstractmethod
131+
def set_colormap(self, map_name: str, image_label: str | None = None) -> None:
132+
"""
133+
Set the colormap for the image specified by image_label.
134+
135+
Parameters
136+
----------
137+
map_name : str
138+
The name of the colormap to set. This should be a
139+
valid colormap name from Matplotlib`_;
140+
not all backends will support
141+
all colormaps, so the viewer should handle errors gracefully.
142+
image_label : str, optional
143+
The label of the image to set the colormap for. If not given and there is
144+
only one image loaded, the colormap for that image is set. If there are
145+
multiple images and no label is provided, an error is raised.
146+
147+
Raises
148+
------
149+
ValueError
150+
If the `map_name` is not a valid colormap name or if the `image_label`
151+
is not provided when there are multiple images loaded.
152+
153+
.. _Matplotlib: https://matplotlib.org/stable/gallery/color/colormap_reference.html
154+
"""
155+
raise NotImplementedError
156+
157+
@abstractmethod
158+
def get_colormap(self, image_label: str | None = None) -> str:
159+
"""
160+
Get the current colormap for the image.
161+
162+
Parameters
163+
----------
164+
image_label : str, optional
165+
The label of the image to get the colormap for. If not given and there is
166+
only one image loaded, the colormap for that image is returned. If there are
167+
multiple images and no label is provided, an error is raised.
168+
169+
Returns
170+
-------
171+
map_name : str
172+
The name of the current colormap.
173+
174+
Raises
175+
------
176+
ValueError
177+
If the `image_label` is not provided when there are multiple images loaded or if
178+
the `image_label` does not correspond to a loaded image.
179+
"""
180+
raise NotImplementedError
181+
133182
# Saving contents of the view and accessing the view
134183
@abstractmethod
135184
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:

src/astro_image_display_api/widget_api_test.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
from astropy.coordinates import SkyCoord
88
from astropy.io import fits
99
from astropy.nddata import CCDData, NDData
10-
from astropy.table import Table, vstack
10+
from astropy.table import Table
1111
from astropy import units as u
1212
from astropy.wcs import WCS
1313
from astropy.visualization import AsymmetricPercentileInterval, BaseInterval, BaseStretch, LogStretch, ManualInterval
1414

15+
from .interface_definition import ImageViewerInterface
16+
1517
__all__ = ['ImageWidgetAPITest']
1618

1719

@@ -716,15 +718,40 @@ def test_stretch_cuts_errors(self, data):
716718
with pytest.raises(ValueError, match='[Ii]mage label.*not found'):
717719
self.image.set_cuts((10, 100), image_label='not a valid label')
718720

719-
@pytest.mark.skip(reason="Not clear whether colormap is part of the API")
720-
def test_colormap(self):
721+
def test_set_get_colormap(self, data):
722+
# Check setting and getting with a single image label.
723+
self.image.load_image(data, image_label='test')
721724
cmap_desired = 'gray'
722-
cmap_list = self.image.colormap_options
723-
assert len(cmap_list) > 0 and cmap_desired in cmap_list
724725
self.image.set_colormap(cmap_desired)
726+
assert self.image.get_colormap() == cmap_desired
727+
728+
# Check that the colormap can be set with an image label
729+
new_cmap = "viridis"
730+
self.image.set_colormap(new_cmap, image_label='test')
731+
assert self.image.get_colormap(image_label='test') == new_cmap
732+
733+
def test_set_colormap_errors(self, data):
734+
# Check that setting a colormap raises an error if the colormap
735+
# is not in the list of allowed colormaps.
736+
self.image.load_image(data, image_label='test')
737+
738+
# Check that getting a colormap for an image label that does not exist
739+
with pytest.raises(ValueError, match='[Ii]mage label.*not found'):
740+
self.image.get_colormap(image_label='not a valid label')
741+
742+
# Check that setting a colormap without an image label fails
743+
# when there is more than one image label
744+
self.image.load_image(data, image_label='another test')
745+
with pytest.raises(ValueError, match='Multiple image labels defined'):
746+
self.image.set_colormap('gray')
747+
748+
# Same for getting the colormap without an image label
749+
with pytest.raises(ValueError, match='Multiple image labels defined'):
750+
self.image.get_colormap()
725751

726752
def test_cursor(self):
727753
assert self.image.cursor in self.image.ALLOWED_CURSOR_LOCATIONS
754+
assert set(ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS) == set(self.image.ALLOWED_CURSOR_LOCATIONS)
728755
with pytest.raises(self.cursor_error_classes):
729756
self.image.cursor = 'not a valid option'
730757
self.image.cursor = 'bottom'

0 commit comments

Comments
 (0)