Skip to content

Commit 8dfda13

Browse files
eteqmwcraig
authored andcommitted
update tests for added methods
1 parent dc5899e commit 8dfda13

File tree

3 files changed

+35
-2
lines changed

3 files changed

+35
-2
lines changed

src/astro_image_display_api/api_test.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,3 +797,24 @@ def test_save_overwrite(self, tmp_path):
797797

798798
# Using overwrite should save successfully
799799
self.image.save(filename, overwrite=True)
800+
801+
def test_get_image_labels(self, data):
802+
# the test viewer begins with a default empty image
803+
assert len(self.image.get_image_labels()) == 1
804+
assert self.image.get_image_labels()[0] == None
805+
assert isinstance(self.image.get_image_labels(), tuple)
806+
807+
self.image.load_image(data, image_label="test")
808+
assert len(self.image.get_image_labels()) == 2
809+
assert self.image.get_image_labels()[-1] == 'test'
810+
811+
def test_get_image(self, data):
812+
self.image.load_image(data, image_label="test")
813+
814+
# currently the type is not specified in the API
815+
assert self.image.get_image() is not None
816+
assert self.image.get_image(image_label="test") is not None
817+
818+
with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
819+
self.image.get_image(image_label="not a valid label")
820+

src/astro_image_display_api/image_viewer_logic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,18 @@ def load_image(
332332
# working with the new image.
333333
self._wcs = self._images[image_label].wcs
334334

335+
336+
def get_image(self, image_label: str | None = None):
337+
image_label = self._resolve_image_label(image_label)
338+
if image_label not in self._images:
339+
raise ValueError(
340+
f"Image label '{image_label}' not found. Please load an image first."
341+
)
342+
return self._images[image_label]
343+
344+
def get_image_labels(self):
345+
return tuple(self._images.keys())
346+
335347
def _determine_largest_dimension(self, shape: tuple[int, int]) -> int:
336348
"""
337349
Determine which index is the largest dimension.

src/astro_image_display_api/interface_definition.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,13 @@ def get_image(
8080
@abstractmethod
8181
def get_image_labels(
8282
self,
83-
) -> list[str]:
83+
) -> tuple[str]:
8484
"""
8585
Get the labels of the loaded images.
8686
8787
Returns
8888
-------
89-
image_labels: list of str
89+
image_labels: tuple of str
9090
The labels of the loaded images.
9191
"""
9292
raise NotImplementedError

0 commit comments

Comments
 (0)