Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions src/astro_image_display_api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,10 @@
)
from astropy.wcs import WCS

from .interface_definition import ImageViewerInterface

__all__ = ["ImageAPITest"]


class ImageAPITest:
cursor_error_classes = ValueError

@pytest.fixture
def data(self):
rng = np.random.default_rng(1234)
Expand Down Expand Up @@ -768,16 +764,6 @@ def test_set_colormap_errors(self, data):
with pytest.raises(ValueError, match="Multiple image labels defined"):
self.image.get_colormap()

def test_cursor(self):
assert self.image.cursor in self.image.ALLOWED_CURSOR_LOCATIONS
assert set(ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS) == set(
self.image.ALLOWED_CURSOR_LOCATIONS
)
with pytest.raises(self.cursor_error_classes):
self.image.cursor = "not a valid option"
self.image.cursor = "bottom"
assert self.image.cursor == "bottom"

def test_save(self, tmp_path):
filename = tmp_path / "woot.png"
self.image.save(filename)
Expand Down
18 changes: 0 additions & 18 deletions src/astro_image_display_api/image_viewer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,10 @@ class ImageViewerLogic:
image_width: int = 0
image_height: int = 0
zoom_level: float = 1
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(
upper_percentile=95
)
_stretch: BaseStretch = LinearStretch
# viewer: Any

# Allowed locations for cursor display
ALLOWED_CURSOR_LOCATIONS: tuple = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS

# some internal variable for keeping track of viewer state
_wcs: WCS | None = None
Expand Down Expand Up @@ -201,19 +196,6 @@ def get_colormap(self, image_label: str | None = None) -> str:

get_colormap.__doc__ = ImageViewerInterface.get_colormap.__doc__

@property
def cursor(self) -> str:
return self._cursor

@cursor.setter
def cursor(self, value: str) -> None:
if value not in self.ALLOWED_CURSOR_LOCATIONS:
raise ValueError(
f"Cursor location {value} is not valid. Must be one of "
f"{self.ALLOWED_CURSOR_LOCATIONS}."
)
self._cursor = value

# The methods, grouped loosely by purpose

def get_catalog_style(self, catalog_label=None) -> dict[str, Any]:
Expand Down
15 changes: 4 additions & 11 deletions src/astro_image_display_api/interface_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
from astropy.units import Quantity
from astropy.visualization import BaseInterval, BaseStretch

# Allowed locations for cursor display
ALLOWED_CURSOR_LOCATIONS = ("top", "bottom", None)

__all__ = [
"ImageViewerInterface",
]
Expand All @@ -23,10 +20,6 @@ class ImageViewerInterface(Protocol):
# do any checking at all of these types.
image_width: int
image_height: int
cursor: str

# Allowed locations for cursor display
ALLOWED_CURSOR_LOCATIONS: tuple = ALLOWED_CURSOR_LOCATIONS

# The methods, grouped loosely by purpose

Expand Down Expand Up @@ -441,10 +434,10 @@ def set_viewport(
The center of the viewport. If not given, the current center is used.
fov : `astropy.units.Quantity` or float, optional
The field of view (FOV) of the viewport. If not given, the current FOV
is used. If a float is given, it is interpreted as a size in pixels. For viewers
that are not square, the FOV is interpreted as the size of the shorter side
of the viewer such that the FOV is guaranteed to be entirely visible
regardless of the aspect ratio of the viewer.
is used. If a float is given, it is interpreted as a size in pixels. For
viewers that are not square, the FOV is interpreted as the size of the
shorter side of the viewer such that the FOV is guaranteed to be entirely
visible regardless of the aspect ratio of the viewer.
image_label : str, optional
The label of the image to set the viewport for. If not given and there is
only one image loaded, the viewport for that image is set. If there are
Expand Down
45 changes: 42 additions & 3 deletions tests/test_astro_image_display_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import re

from astropy.utils.data import get_pkg_data_contents

from astro_image_display_api import ImageViewerInterface


def test_api_test_class_completeness():
def test_api_test_class_covers_all_attributes_and_only_those_attributes():
"""
Test that the ImageWidgetAPITest class is complete and has tests
for all of the required methods and attributes.
for all of the required methods and attributes and does not access
any attributes or methods that are not part of the ImageViewerInterface.
"""
# Get the attributes on the protocol
required_attributes = ImageViewerInterface.__protocol_attrs__
Expand All @@ -15,11 +18,47 @@ def test_api_test_class_completeness():
api_test_content = get_pkg_data_contents(
"api_test.py", package="astro_image_display_api"
)

# This is the way the test class is accessed in the api_test.py file.
image_viewer_name = "self.image"

# Get all of the methods and attributes that are accessed
# in the api_test.py file.
# We use a regex to find all occurrences of the image_viewer_name
# followed by a dot and then an attribute name.
# This will match both attributes and methods.
attributes_accessed_in_test_class = re.findall(
rf"{image_viewer_name.replace(".", r"\.")}\.([a-zA-Z_][a-zA-Z0-9_]*)",
api_test_content,
)

# Get the attribute/method names as a set
attributes_accessed_in_test_class = list(set(attributes_accessed_in_test_class))

# Make sure that the test class does not access any attributes
# or methods that are not part of the ImageViewerInterface.
attr_in_test_class_is_in_interface = []
for attr in attributes_accessed_in_test_class:
attr_in_test_class_is_in_interface.append(attr in required_attributes)

attr_not_present_in_interface = [
attr
for attr, present in zip(
attributes_accessed_in_test_class,
attr_in_test_class_is_in_interface,
strict=True,
)
if not present
]

assert all(attr_in_test_class_is_in_interface), (
f"ImageWidgetAPITest accesses these attributes/methods that are not part of "
f"the ImageViewerInterface:\n{', '.join(attr_not_present_in_interface)}\n"
)
# Loop over the attributes and check that the test class has a method
# for each one whose name starts with test_ and ends with the attribute
# name.
attr_present = []
image_viewer_name = "self.image"
for attr in required_attributes:
attr_present.append(f"{image_viewer_name}.{attr}" in api_test_content)

Expand Down