Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions src/astro_image_display_api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,3 +822,97 @@ def test_get_image(self, data):

with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
self.image.get_image(image_label="not a valid label")

def test_all_methods_accept_additional_kwargs(self, data, catalog, tmp_path):
"""
Make sure all methods accept additional keyword arguments
that are not defined in the protocol.
"""
from astro_image_display_api import ImageViewerInterface

all_methods_and_attributes = ImageViewerInterface.__protocol_attrs__

all_methods = [
method
for method in all_methods_and_attributes
if callable(getattr(self.image, method))
]

# Make a small dictionary keys that are random characters
additional_kwargs = {k: f"value{k}" for k in ["fsda", "urioeh", "m898h]"]}

# Make a dictionary of the required arguments for any methods that have required
# argument
required_args = dict(
load_image=data,
set_cuts=(10, 100),
set_stretch=LogStretch(),
set_colormap="viridis",
save=tmp_path / "test.png",
load_catalog=catalog,
)

failed_methods = []

# Take out the loading methods because they must happen first and take out
# remove_catalog because it must happen last.
all_methods = list(
set(all_methods) - set(["load_image", "load_catalog", "remove_catalog"])
)

# Load an image and a catalog first since other methods require these
# have been done
try:
self.image.load_image(required_args["load_image"], **additional_kwargs)
except TypeError as e:
if "required positional argument" not in str(e):
# If the error is not about a missing required argument, we
# consider it a failure.
failed_methods.append("load_image")
else:
raise e

try:
self.image.load_catalog(required_args["load_catalog"], **additional_kwargs)
except TypeError as e:
if "required positional argument" not in str(e):
# If the error is not about a missing required argument, we
# consider it a failure.
failed_methods.append("load_catalog")
else:
raise e

if not failed_methods:
# No point in running some of these if setting image or catalog has failed.
# Run remove_catalog last so that it does not interfere with the
# other methods that require an image or catalog to be loaded.
for method in all_methods + ["remove_catalog"]:
# Call each method with the required arguments and additional kwargs
# Accumulate the failures and report them at the end
try:
if method in required_args:
# If the method has required arguments, call it with those
getattr(self.image, method)(
required_args[method], **additional_kwargs
)
else:
# If the method does not have required arguments, just call it
# with additional kwargs
getattr(self.image, method)(**additional_kwargs)
except TypeError as e:
if "required positional argument" not in str(e):
# If the error is not about a missing required argument, we
# consider it a failure.
failed_methods.append(method)
else:
raise e

else:
failed_methods.append(
"No other methods were tested because the ones above failed."
)

assert not failed_methods, (
"The following methods failed when called with additional kwargs:\n\t"
f"{'\n\t'.join(failed_methods)}"
)
75 changes: 63 additions & 12 deletions src/astro_image_display_api/image_viewer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,24 @@ def _default_catalog_style(self) -> dict[str, Any]:
"size": 5,
}

def get_stretch(self, image_label: str | None = None) -> BaseStretch:
def get_stretch(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do get stuff really need **kwargs? I thought we want flexibility to set things but get should be straightforward? Am I missing something?

Same comment on all the other get methods.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wondered the same thing....definitely get_catalog_names and get_image_labels will get turned into properties without an arguments once this is merged.

self,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> BaseStretch:
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
f"Image label '{image_label}' not found. Please load an image first."
)
return self._images[image_label].stretch

def set_stretch(self, value: BaseStretch, image_label: str | None = None) -> None:
def set_stretch(
self,
value: BaseStretch,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
if not isinstance(value, BaseStretch):
raise TypeError(
f"Stretch option {value} is not valid. Must be an "
Expand All @@ -150,7 +159,11 @@ def set_stretch(self, value: BaseStretch, image_label: str | None = None) -> Non
)
self._images[image_label].stretch = value

def get_cuts(self, image_label: str | None = None) -> tuple:
def get_cuts(
self,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> tuple:
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
Expand All @@ -162,6 +175,7 @@ def set_cuts(
self,
value: tuple[numbers.Real, numbers.Real] | BaseInterval,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
if isinstance(value, tuple) and len(value) == 2:
self._cuts = ManualInterval(value[0], value[1])
Expand All @@ -179,7 +193,12 @@ def set_cuts(
)
self._images[image_label].cuts = self._cuts

def set_colormap(self, map_name: str, image_label: str | None = None) -> None:
def set_colormap(
self,
map_name: str,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
Expand All @@ -189,7 +208,11 @@ def set_colormap(self, map_name: str, image_label: str | None = None) -> None:

set_colormap.__doc__ = ImageViewerInterface.set_colormap.__doc__

def get_colormap(self, image_label: str | None = None) -> str:
def get_colormap(
self,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> str:
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
Expand All @@ -201,7 +224,11 @@ def get_colormap(self, image_label: str | None = None) -> str:

# The methods, grouped loosely by purpose

def get_catalog_style(self, catalog_label=None) -> dict[str, Any]:
def get_catalog_style(
self,
catalog_label=None,
**kwargs, # noqa: ARG002
) -> dict[str, Any]:
"""
Get the style for the catalog.

Expand Down Expand Up @@ -295,6 +322,7 @@ def load_image(
self,
file: str | os.PathLike | ArrayLike | NDData,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
"""
Load a FITS file into the viewer.
Expand Down Expand Up @@ -333,15 +361,20 @@ def load_image(
# working with the new image.
self._wcs = self._images[image_label].wcs

def get_image(self, image_label: str | None = None):
def get_image(
self, image_label: str | None = None, **kwargs # noqa: ARG002
) -> ArrayLike | NDData | CCDData:
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
f"Image label '{image_label}' not found. Please load an image first."
)
return self._images[image_label].data

def get_image_labels(self):
def get_image_labels(
self,
**kwargs, # noqa: ARG002
) -> tuple[str, ...]:
return tuple(self._images.keys())

def _determine_largest_dimension(self, shape: tuple[int, int]) -> int:
Expand Down Expand Up @@ -458,7 +491,12 @@ def _load_asdf(self, asdf_file: str | os.PathLike, image_label: str | None) -> N
)

# Saving contents of the view and accessing the view
def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
def save(
self,
filename: str | os.PathLike,
overwrite: bool = False,
**kwargs, # noqa: ARG002
) -> None:
"""
Save the current view to a file.

Expand Down Expand Up @@ -490,6 +528,7 @@ def load_catalog(
use_skycoord: bool = False,
catalog_label: str | None = None,
catalog_style: dict | None = None,
**kwargs, # noqa: ARG002
) -> None:
try:
coords = table[skycoord_colname]
Expand Down Expand Up @@ -549,7 +588,11 @@ def load_catalog(

load_catalog.__doc__ = ImageViewerInterface.load_catalog.__doc__

def remove_catalog(self, catalog_label: str | None = None) -> None:
def remove_catalog(
self,
catalog_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
"""
Remove markers from the image.

Expand Down Expand Up @@ -584,6 +627,7 @@ def get_catalog(
y_colname: str = "y",
skycoord_colname: str = "coord",
catalog_label: str | None = None,
**kwargs, # noqa: ARG002
) -> Table:
# Dostring is copied from the interface definition, so it is not
# duplicated here.
Expand All @@ -603,7 +647,10 @@ def get_catalog(

get_catalog.__doc__ = ImageViewerInterface.get_catalog.__doc__

def get_catalog_names(self) -> list[str]:
def get_catalog_names(
self,
**kwargs, # noqa: ARG002
) -> list[str]:
return list(self._user_catalog_labels())

get_catalog_names.__doc__ = ImageViewerInterface.get_catalog_names.__doc__
Expand All @@ -614,6 +661,7 @@ def set_viewport(
center: SkyCoord | tuple[numbers.Real, numbers.Real] | None = None,
fov: Quantity | numbers.Real | None = None,
image_label: str | None = None,
**kwargs, # noqa: ARG002
) -> None:
image_label = self._resolve_image_label(image_label)

Expand Down Expand Up @@ -685,7 +733,10 @@ def set_viewport(
set_viewport.__doc__ = ImageViewerInterface.set_viewport.__doc__

def get_viewport(
self, sky_or_pixel: str | None = None, image_label: str | None = None
self,
sky_or_pixel: str | None = None,
image_label: str | None = None,
**kwargs, # noqa: ARG002
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That said, maybe we do want **kwargs here since viewer in Jdaviz has reference name as such, that might not apply to other backends.

) -> dict[str, Any]:
if sky_or_pixel not in (None, "sky", "pixel"):
raise ValueError("sky_or_pixel must be 'sky', 'pixel', or None.")
Expand Down
Loading