Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions src/astro_image_display_api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

__all__ = ["ImageAPITest"]

DEFAULT_IMAGE_SHAPE = (100, 150)


class ImageAPITest:
@pytest.fixture
def data(self):
rng = np.random.default_rng(1234)
return rng.random((100, 150))
return rng.random(DEFAULT_IMAGE_SHAPE)

@pytest.fixture
def wcs(self):
Expand All @@ -49,8 +51,8 @@ def catalog(self, wcs: WCS) -> Table:
expected columns.
"""
rng = np.random.default_rng(45328975)
x = rng.uniform(0, self.image.image_width, size=10)
y = rng.uniform(0, self.image.image_height, size=10)
x = rng.uniform(0, DEFAULT_IMAGE_SHAPE[0], size=10)
y = rng.uniform(0, DEFAULT_IMAGE_SHAPE[1], size=10)
coord = wcs.pixel_to_world(x, y)

cat = Table(
Expand All @@ -70,7 +72,7 @@ def setup(self):
Subclasses MUST define ``image_widget_class`` -- doing so as a
class variable does the trick.
"""
self.image = self.image_widget_class(image_width=250, image_height=100)
self.image = self.image_widget_class()

def _assert_empty_catalog_table(self, table):
assert isinstance(table, Table)
Expand All @@ -81,17 +83,6 @@ def _get_catalog_names_as_set(self):
marks = self.image.catalog_names
return set(marks)

def test_width_height(self):
assert self.image.image_width == 250
assert self.image.image_height == 100

width = 200
height = 300
self.image.image_width = width
self.image.image_height = height
assert self.image.image_width == width
assert self.image.image_height == height

@pytest.mark.parametrize("load_type", ["fits", "nddata", "array"])
def test_load(self, data, tmp_path, load_type):
match load_type:
Expand Down Expand Up @@ -915,3 +906,25 @@ def test_all_methods_accept_additional_kwargs(self, data, catalog, tmp_path):
"The following methods failed when called with additional kwargs:\n\t"
f"{'\n\t'.join(failed_methods)}"
)

def test_every_method_attribute_has_docstring(self):
"""
Check that every method and attribute in the protocol has a docstring.
"""
from astro_image_display_api import ImageViewerInterface

all_methods_and_attributes = ImageViewerInterface.__protocol_attrs__

method_attrs_no_docs = []

for method in all_methods_and_attributes:
attr = getattr(self.image, method)
# Make list of methods and attributes that have no docstring
# and assert that list is empty at the end of the test.
if not attr.__doc__:
method_attrs_no_docs.append(method)

assert not method_attrs_no_docs, (
"The following methods and attributes have no docstring:\n\t"
f"{'\n\t'.join(method_attrs_no_docs)}"
)
90 changes: 14 additions & 76 deletions src/astro_image_display_api/image_viewer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,19 +53,27 @@ class ViewportInfo:
data: ArrayLike | NDData | CCDData | None = None


def docs_from_interface(cls):
"""
Decorator to copy the docstrings from the interface methods to the
methods in the class.
"""
for name, method in cls.__dict__.items():
if not name.startswith("_"):
interface_method = getattr(ImageViewerInterface, name, None)
if interface_method:
method.__doc__ = interface_method.__doc__
return cls


@dataclass
@docs_from_interface
class ImageViewerLogic:
"""
This viewer does not do anything except making changes to its internal
state to simulate the behavior of a real viewer.
"""

# These are attributes, not methods. The type annotations are there
# to make sure Protocol knows they are attributes. Python does not
# do any checking at all of these types.
image_width: int = 0
image_height: int = 0
zoom_level: float = 1
_cuts: BaseInterval | tuple[float, float] = AsymmetricPercentileInterval(
upper_percentile=95
)
Expand Down Expand Up @@ -206,8 +214,6 @@ def set_colormap(
)
self._images[image_label].colormap = map_name

set_colormap.__doc__ = ImageViewerInterface.set_colormap.__doc__

def get_colormap(
self,
image_label: str | None = None,
Expand All @@ -220,28 +226,13 @@ def get_colormap(
)
return self._images[image_label].colormap

get_colormap.__doc__ = ImageViewerInterface.get_colormap.__doc__

# The methods, grouped loosely by purpose

def get_catalog_style(
self,
catalog_label=None,
**kwargs, # noqa: ARG002
) -> dict[str, Any]:
"""
Get the style for the catalog.

Parameters
----------
catalog_label : str, optional
The label of the catalog. Default is ``None``.

Returns
-------
dict
The style for the catalog.
"""
catalog_label = self._resolve_catalog_label(catalog_label)

style = self._catalogs[catalog_label].style.copy()
Expand All @@ -256,22 +247,6 @@ def set_catalog_style(
size: float = 5,
**kwargs,
) -> None:
"""
Set the style for the catalog.

Parameters
----------
catalog_label : str, optional
The label of the catalog.
shape : str, optional
The shape of the markers.
color : str, optional
The color of the markers.
size : float, optional
The size of the markers.
**kwargs
Additional keyword arguments to pass to the marker style.
"""
catalog_label = self._resolve_catalog_label(catalog_label)

if self._catalogs[catalog_label].data is None:
Expand Down Expand Up @@ -324,18 +299,6 @@ def load_image(
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
"""
Load a FITS file into the viewer.

Parameters
----------
file : str or array-like
The FITS file to load. If a string, it can be a URL or a
file path.

image_label : str, optional
A label for the image.
"""
image_label = self._resolve_image_label(image_label)

# Delete the current viewport if it exists
Expand Down Expand Up @@ -375,8 +338,6 @@ def get_image(
def image_labels(self) -> tuple[str, ...]:
return tuple(k for k in self._images.keys() if k is not None)

image_labels.__doc__ = ImageViewerInterface.image_labels.__doc__

def _determine_largest_dimension(self, shape: tuple[int, int]) -> int:
"""
Determine which index is the largest dimension.
Expand Down Expand Up @@ -497,19 +458,6 @@ def save(
overwrite: bool = False,
**kwargs, # noqa: ARG002
) -> None:
"""
Save the current view to a file.

Parameters
----------
filename : str or `os.PathLike`
The file to save to. The format is determined by the
extension.

overwrite : bool, optional
If `True`, overwrite the file if it exists. Default is
`False`.
"""
p = Path(filename)
if p.exists() and not overwrite:
raise FileExistsError(
Expand Down Expand Up @@ -586,8 +534,6 @@ def load_catalog(

self._catalogs[catalog_label].style = catalog_style

load_catalog.__doc__ = ImageViewerInterface.load_catalog.__doc__

def remove_catalog(
self,
catalog_label: str | None = None,
Expand Down Expand Up @@ -645,14 +591,10 @@ def get_catalog(

return result

get_catalog.__doc__ = ImageViewerInterface.get_catalog.__doc__

@property
def catalog_names(self) -> tuple[str, ...]:
return tuple(self._user_catalog_labels())

catalog_names.__doc__ = ImageViewerInterface.catalog_names.__doc__

# Methods that modify the view
def set_viewport(
self,
Expand Down Expand Up @@ -728,8 +670,6 @@ def set_viewport(
self._images[image_label].center = center
self._images[image_label].fov = fov

set_viewport.__doc__ = ImageViewerInterface.set_viewport.__doc__

def get_viewport(
self,
sky_or_pixel: str | None = None,
Expand Down Expand Up @@ -805,5 +745,3 @@ def get_viewport(
fov = viewport.fov

return dict(center=center, fov=fov, image_label=image_label)

get_viewport.__doc__ = ImageViewerInterface.get_viewport.__doc__
6 changes: 0 additions & 6 deletions src/astro_image_display_api/interface_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@

@runtime_checkable
class ImageViewerInterface(Protocol):
# These are attributes, not methods. The type annotations are there
# to make sure Protocol knows they are attributes. Python does not
# do any checking at all of these types.
image_width: int
image_height: int

# The methods, grouped loosely by purpose

# Method for loading image data
Expand Down
23 changes: 23 additions & 0 deletions tests/test_astro_image_display_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,26 @@ def test_api_test_class_covers_all_attributes_and_only_those_attributes():
"ImageWidgetAPITest does not access these "
f"attributes/methods:\n{missing_attributes_msg}\n"
)


def test_every_method_attribute_has_docstring():
Copy link

Copilot AI Jul 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] This test is duplicated in api_test.py; consider extracting a shared helper or fixture to avoid copy–paste.

Copilot uses AI. Check for mistakes.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is deliberate. One is used internally, one is meant to be used externally.

"""
Check that every method and attribute in the protocol has a docstring.
"""
from astro_image_display_api import ImageViewerInterface

all_methods_and_attributes = ImageViewerInterface.__protocol_attrs__

method_attrs_no_docs = []

for method in all_methods_and_attributes:
attr = getattr(ImageViewerInterface, method)
# Make list of methods and attributes that have no docstring
# and assert that list is empty at the end of the test.
if not attr.__doc__:
method_attrs_no_docs.append(method)

assert not method_attrs_no_docs, (
"The following methods and attributes have no docstring:\n\t"
f"{'\n\t'.join(method_attrs_no_docs)}"
)