diff --git a/src/astro_image_display_api/dummy_viewer.py b/src/astro_image_display_api/dummy_viewer.py index dfc848c..ecf4bde 100644 --- a/src/astro_image_display_api/dummy_viewer.py +++ b/src/astro_image_display_api/dummy_viewer.py @@ -127,7 +127,7 @@ def get_stretch(self, image_label: str | None = None) -> BaseStretch: def set_stretch(self, value: BaseStretch, image_label: str | None = None) -> None: if not isinstance(value, BaseStretch): - raise ValueError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.") + raise TypeError(f"Stretch option {value} is not valid. Must be an Astropy.visualization Stretch object.") image_label = self._resolve_image_label(image_label) if image_label not in self._images: raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") @@ -145,7 +145,7 @@ def set_cuts(self, value: tuple[numbers.Real, numbers.Real] | BaseInterval, imag elif isinstance(value, BaseInterval): self._cuts = value else: - raise ValueError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.") + raise TypeError("Cuts must be an Astropy.visualization Interval object or a tuple of two values.") image_label = self._resolve_image_label(image_label) if image_label not in self._images: raise ValueError(f"Image label '{image_label}' not found. Please load an image first.") @@ -498,7 +498,7 @@ def remove_catalog(self, catalog_label: str | None = None) -> None: then all markers will be removed. """ if isinstance(catalog_label, list): - raise ValueError( + raise TypeError( "Cannot remove multiple catalogs from a list. Please specify " "a single catalog label or use '*' to remove all catalogs." ) diff --git a/src/astro_image_display_api/interface_definition.py b/src/astro_image_display_api/interface_definition.py index ae7e289..43e83fd 100644 --- a/src/astro_image_display_api/interface_definition.py +++ b/src/astro_image_display_api/interface_definition.py @@ -1,5 +1,6 @@ from typing import Protocol, runtime_checkable, Any from abc import abstractmethod +import numbers import os from astropy.coordinates import SkyCoord @@ -55,7 +56,7 @@ def load_image(self, data: Any, image_label: str | None = None) -> None: # Setting and getting image properties @abstractmethod - def set_cuts(self, cuts: tuple | BaseInterval, image_label: str | None = None) -> None: + def set_cuts(self, cuts: tuple[numbers.Real, numbers.Real] | BaseInterval, image_label: str | None = None) -> None: """ Set the cuts for the image. @@ -69,6 +70,16 @@ def set_cuts(self, cuts: tuple | BaseInterval, image_label: str | None = None) - image_label : str, optional The label of the image to set the cuts for. If not given and there is only one image loaded, the cuts for that image are set. + + Raises + ------ + TypeError + If the `cuts` parameter is not a tuple or an `astropy.visualization.BaseInterval` + object. + + ValueError + If the `image_label` is not provided when there are multiple images loaded, + or if the `image_label` does not correspond to a loaded image. """ raise NotImplementedError @@ -88,6 +99,12 @@ def get_cuts(self, image_label: str | None = None) -> BaseInterval: ------- cuts : `~astropy.visualization.BaseInterval` The Astropy interval object representing the current cuts. + + Raises + ------ + ValueError + If the `image_label` is not provided when there are multiple images loaded, + or if the `image_label` does not correspond to a loaded image. """ raise NotImplementedError @@ -103,8 +120,17 @@ def set_stretch(self, stretch: BaseStretch, image_label: str | None = None) -> N `~astropy.visualization.BaseStretch`. image_label : str, optional - The label of the image to set the cuts for. If not given and there is - only one image loaded, the cuts for that image are set. + The label of the image to set the stretch for. If not given and there is + only one image loaded, the stretch for that image are set. + + Raises + ------ + TypeError + If the `stretch` is not a valid `BaseStretch` object + + ValueError + if the `image_label` is not provided when there are multiple images loaded + or if the `image_label` does not correspond to a loaded image. """ raise NotImplementedError @@ -194,6 +220,11 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None: overwrite : bool, optional If `True`, overwrite the file if it exists. Default is `False`. + + Raises + ------ + FileExistsError + If the file already exists and `overwrite` is `False`. """ raise NotImplementedError @@ -229,6 +260,13 @@ def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y', A dictionary that specifies the style of the markers used to represent the catalog. See `ImageViewerInterface.set_catalog_style` for details. + + Raises + ------ + ValueError + If the `table` does not contain the required columns, or if + the `catalog_label` is not provided when there are multiple + catalogs loaded. """ raise NotImplementedError @@ -262,6 +300,13 @@ def set_catalog_style( ----- The following shapes are supported: "circle", "square", "crosshair", "plus", "diamond". + + Raises + ------ + ValueError + If there are multiple catalog styles set and the user has not + specified a `catalog_label` for which to set the style, or if + an style is set for a catalog that does not exist. """ raise NotImplementedError @@ -270,6 +315,15 @@ def get_catalog_style(self, catalog_label: str | None = None) -> dict: """ Get the style of the catalog markers. + Parameters + ---------- + catalog_label : str, optional + The name of the catalog. If not given and there is + only one catalog loaded, the style for that catalog is returned. + If there are multiple catalogs and no label is provided, an error + is raised. If the label does not correspond to a loaded + catalog, an empty dictionary is returned. + Returns ------- dict @@ -281,26 +335,39 @@ def get_catalog_style(self, catalog_label: str | None = None) -> dict: ValueError If there are multiple catalog styles set and the user has not specified a `catalog_label` for which to get the style. + """ raise NotImplementedError @abstractmethod - def remove_catalog(self, catalog_label: str | list[str] | None = None) -> None: + def remove_catalog(self, catalog_label: str | None = None) -> None: """ Remove markers from the image. Parameters ---------- catalog_label : str, optional - The name of the marker set to remove. If the value is ``"all"``, - then all markers will be removed. + The name of the catalog to remove. The value ``'*'`` can be used to + remove all catalogs. If not given and there is + only one catalog loaded, that catalog is removed. + + Raises + ------ + ValueError + If the `catalog_label` is not provided when there are multiple + catalogs loaded, or if the `catalog_label` does not correspond to a + loaded catalog. + + TypeError + If the `catalog_label` is not a string or `None`, or if it is not + one of the allowed values. """ raise NotImplementedError @abstractmethod def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y', skycoord_colname: str = 'coord', - catalog_label: str | list[str] | None = None) -> Table: + catalog_label: str | None = None) -> Table: """ Get the marker positions. @@ -315,15 +382,19 @@ def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y', skycoord_colname : str, optional The name of the column containing the sky coordinates. Default is ``'coord'``. - catalog_label : str or list of str, optional - The name of the marker set to use. If that value is ``"all"``, - then all markers will be returned. + catalog_label : str, optional + The name of the catalog set to get. Returns ------- table : `astropy.table.Table` The table containing the marker positions. If no markers match the ``catalog_label`` parameter, an empty table is returned. + + Raises + ------ + ValueError + If the `catalog_label` is not provided when there are multiple catalogs loaded. """ raise NotImplementedError @@ -408,6 +479,7 @@ def get_viewport(self, sky_or_pixel: str | None = None, image_label: str | None ------- ValueError If the `sky_or_pixel` parameter is not one of 'sky', 'pixel', or `None`, or if - the `image_label` is not provided when there are multiple images loaded. + the `image_label` is not provided when there are multiple images loaded, or if + the `image_label` does not correspond to a loaded image. """ raise NotImplementedError diff --git a/src/astro_image_display_api/widget_api_test.py b/src/astro_image_display_api/widget_api_test.py index 57484f2..0bb183f 100644 --- a/src/astro_image_display_api/widget_api_test.py +++ b/src/astro_image_display_api/widget_api_test.py @@ -616,7 +616,7 @@ def test_remove_catalog_does_not_accept_list(self): self.image.load_catalog(tab, catalog_label='test2', use_skycoord=False) with pytest.raises( - ValueError, + TypeError, match='Cannot remove multiple catalogs from a list' ): self.image.remove_catalog(catalog_label=['test1', 'test2']) @@ -646,7 +646,7 @@ def test_adding_catalog_as_world(self, data, wcs): def test_stretch(self): original_stretch = self.image.get_stretch() - with pytest.raises(ValueError, match=r'Stretch.*not valid.*'): + with pytest.raises(TypeError, match=r'Stretch.*not valid.*'): self.image.set_stretch('not a valid value') # A bad value should leave the stretch unchanged @@ -658,10 +658,10 @@ def test_stretch(self): assert isinstance(self.image.get_stretch(), LogStretch) def test_cuts(self, data): - with pytest.raises(ValueError, match='[mM]ust be'): + with pytest.raises(TypeError, match='[mM]ust be'): self.image.set_cuts('not a valid value') - with pytest.raises(ValueError, match='[mM]ust be'): + with pytest.raises(TypeError, match='[mM]ust be'): self.image.set_cuts((1, 10, 100)) # Setting using histogram requires data