Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/astro_image_display_api/dummy_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_stretch(self, image_label: str | None = None) -> BaseStretch:

def set_stretch(self, value: BaseStretch, image_label: str | None = None) -> None:
if not isinstance(value, BaseStretch):
raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.")
raise TypeError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.")
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(f"Image label '{image_label}' not found. Please load an image first.")
Expand All @@ -145,7 +145,7 @@ def set_cuts(self, value: tuple[numbers.Real, numbers.Real] | BaseInterval, imag
elif isinstance(value, BaseInterval):
self._cuts = value
else:
raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
raise TypeError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.")
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(f"Image label '{image_label}' not found. Please load an image first.")
Expand Down Expand Up @@ -498,7 +498,7 @@ def remove_catalog(self, catalog_label: str | None = None) -> None:
then all markers will be removed.
"""
if isinstance(catalog_label, list):
raise ValueError(
raise TypeError(
"Cannot remove multiple catalogs from a list. Please specify "
"a single catalog label or use '*' to remove all catalogs."
)
Expand Down
83 changes: 78 additions & 5 deletions src/astro_image_display_api/interface_definition.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Protocol, runtime_checkable, Any
from abc import abstractmethod
import numbers
import os

from astropy.coordinates import SkyCoord
Expand Down Expand Up @@ -55,7 +56,7 @@ def load_image(self, data: Any, image_label: str | None = None) -> None:

# Setting and getting image properties
@abstractmethod
def set_cuts(self, cuts: tuple | BaseInterval, image_label: str | None = None) -> None:
def set_cuts(self, cuts: tuple[numbers.Real, numbers.Real] | BaseInterval, image_label: str | None = None) -> None:
"""
Set the cuts for the image.

Expand All @@ -69,6 +70,16 @@ def set_cuts(self, cuts: tuple | BaseInterval, image_label: str | None = None) -
image_label : str, optional
The label of the image to set the cuts for. If not given and there is
only one image loaded, the cuts for that image are set.

Raises
------
TypeError
If the `cuts` parameter is not a tuple or an `astropy.visualization.BaseInterval`
object.

ValueError
If the `image_label` is not provided when there are multiple images loaded,
or if the `image_label` does not correspond to a loaded image.
"""
raise NotImplementedError

Expand All @@ -88,6 +99,12 @@ def get_cuts(self, image_label: str | None = None) -> BaseInterval:
-------
cuts : `~astropy.visualization.BaseInterval`
The Astropy interval object representing the current cuts.

Raises
------
ValueError
If the `image_label` is not provided when there are multiple images loaded,
or if the `image_label` does not correspond to a loaded image.
"""
raise NotImplementedError

Expand All @@ -105,6 +122,15 @@ def set_stretch(self, stretch: BaseStretch, image_label: str | None = None) -> N
image_label : str, optional
The label of the image to set the cuts for. If not given and there is
only one image loaded, the cuts for that image are set.

Raises
------
TypeError
If the `stretch` is not a valid `BaseStretch` object

ValueError
if the `image_label` is not provided when there are multiple images loaded
or if the `image_label` does not correspond to a loaded image.
"""
raise NotImplementedError

Expand Down Expand Up @@ -194,6 +220,11 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
overwrite : bool, optional
If `True`, overwrite the file if it exists. Default is
`False`.

Raises
------
FileExistsError
If the file already exists and `overwrite` is `False`.
"""
raise NotImplementedError

Expand Down Expand Up @@ -229,6 +260,13 @@ def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
A dictionary that specifies the style of the markers used to
represent the catalog. See `ImageViewerInterface.set_catalog_style`
for details.

Raises
------
ValueError
If the `table` does not contain the required columns, or if
the `catalog_label` is not provided when there are multiple
catalogs loaded.
"""
raise NotImplementedError

Expand Down Expand Up @@ -262,6 +300,13 @@ def set_catalog_style(
-----
The following shapes are supported: "circle", "square", "crosshair", "plus",
"diamond".

Raises
------
ValueError
If there are multiple catalog styles set and the user has not
specified a `catalog_label` for which to set the style, or if
an style is set for a catalog that does not exist.
"""
raise NotImplementedError

Expand All @@ -270,6 +315,15 @@ def get_catalog_style(self, catalog_label: str | None = None) -> dict:
"""
Get the style of the catalog markers.

Parameters
----------
catalog_label : str, optional
The name of the catalog. If not given and there is
only one catalog loaded, the style for that catalog is returned.
If there are multiple catalogs and no label is provided, an error
is raised. If the label does not correspond to a loaded
catalog, an empty dictionary is returned.

Returns
-------
dict
Expand All @@ -281,19 +335,32 @@ def get_catalog_style(self, catalog_label: str | None = None) -> dict:
ValueError
If there are multiple catalog styles set and the user has not
specified a `catalog_label` for which to get the style.

"""
raise NotImplementedError

@abstractmethod
def remove_catalog(self, catalog_label: str | list[str] | None = None) -> None:
def remove_catalog(self, catalog_label: str | None = None) -> None:
"""
Remove markers from the image.

Parameters
----------
catalog_label : str, optional
The name of the marker set to remove. If the value is ``"all"``,
then all markers will be removed.
The name of the catalog to remove. The value ``'*'`` can be used to
remove all catalogs. If not given and there is
only one catalog loaded, that catalog is removed.

Raises
------
ValueError
If the `catalog_label` is not provided when there are multiple
catalogs loaded, or if the `catalog_label` does not correspond to a
loaded catalog.

TypeError
If the `catalog_label` is not a string or `None`, or if it is not
one of the allowed values.
"""
raise NotImplementedError

Expand Down Expand Up @@ -324,6 +391,11 @@ def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y',
table : `astropy.table.Table`
The table containing the marker positions. If no markers match the
``catalog_label`` parameter, an empty table is returned.

Raises
------
ValueError
If the `catalog_label` is not provided when there are multiple catalogs loaded.
"""
raise NotImplementedError

Expand Down Expand Up @@ -408,6 +480,7 @@ def get_viewport(self, sky_or_pixel: str | None = None, image_label: str | None
-------
ValueError
If the `sky_or_pixel` parameter is not one of 'sky', 'pixel', or `None`, or if
the `image_label` is not provided when there are multiple images loaded.
the `image_label` is not provided when there are multiple images loaded, or if
the `image_label` does not correspond to a loaded image.
"""
raise NotImplementedError
8 changes: 4 additions & 4 deletions src/astro_image_display_api/widget_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def test_remove_catalog_does_not_accept_list(self):
self.image.load_catalog(tab, catalog_label='test2', use_skycoord=False)

with pytest.raises(
ValueError,
TypeError,
match='Cannot remove multiple catalogs from a list'
):
self.image.remove_catalog(catalog_label=['test1', 'test2'])
Expand Down Expand Up @@ -646,7 +646,7 @@ def test_adding_catalog_as_world(self, data, wcs):
def test_stretch(self):
original_stretch = self.image.get_stretch()

with pytest.raises(ValueError, match=r'Stretch.*not valid.*'):
with pytest.raises(TypeError, match=r'Stretch.*not valid.*'):
self.image.set_stretch('not a valid value')

# A bad value should leave the stretch unchanged
Expand All @@ -658,10 +658,10 @@ def test_stretch(self):
assert isinstance(self.image.get_stretch(), LogStretch)

def test_cuts(self, data):
with pytest.raises(ValueError, match='[mM]ust be'):
with pytest.raises(TypeError, match='[mM]ust be'):
self.image.set_cuts('not a valid value')

with pytest.raises(ValueError, match='[mM]ust be'):
with pytest.raises(TypeError, match='[mM]ust be'):
self.image.set_cuts((1, 10, 100))

# Setting using histogram requires data
Expand Down