Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/astro_image_display_api/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,3 +797,28 @@ def test_save_overwrite(self, tmp_path):

# Using overwrite should save successfully
self.image.save(filename, overwrite=True)

def test_get_image_labels(self, data):
# the test viewer begins with a default empty image
assert len(self.image.get_image_labels()) == 1
assert self.image.get_image_labels()[0] is None
assert isinstance(self.image.get_image_labels(), tuple)

self.image.load_image(data, image_label="test")
assert len(self.image.get_image_labels()) == 2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting; logically I would have thought of this as being 1.

Is this too complicated: None only counts as an image label if it has data associated with it

So the idea would be that if, like in this example, the only loaded data was for an image label you set, the get_image_labels returns just the image label you have used.

Doing this would return two image labels:

# Assume we are starting with a clean slate

assert self.image.get_image_labels() is None  # or maybe == []

# Load an image without a label, which means the label is `None`
self.image.load_image(data)

# Load a second image with an explicit label
self.image.load_image(data, label="test")

assert len(self.image.get_image_labels()) == 2  # None and "test"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tracking this in #74

assert self.image.get_image_labels()[-1] == "test"

def test_get_image(self, data):
self.image.load_image(data, image_label="test")

# currently the type is not specified in the API
assert self.image.get_image() is not None
assert self.image.get_image(image_label="test") is not None

retrieved_image = self.image.get_image(image_label="test")

self.image.load_image(retrieved_image, image_label="another test")
assert self.image.get_image(image_label="another test") is not None

with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
self.image.get_image(image_label="not a valid label")
15 changes: 15 additions & 0 deletions src/astro_image_display_api/image_viewer_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class ViewportInfo:
stretch: BaseStretch | None = None
cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None
colormap: str | None = None
data: ArrayLike | NDData | CCDData | None = None


@dataclass
Expand Down Expand Up @@ -332,6 +333,17 @@ def load_image(
# working with the new image.
self._wcs = self._images[image_label].wcs

def get_image(self, image_label: str | None = None):
image_label = self._resolve_image_label(image_label)
if image_label not in self._images:
raise ValueError(
f"Image label '{image_label}' not found. Please load an image first."
)
return self._images[image_label].data

def get_image_labels(self):
return tuple(self._images.keys())

def _determine_largest_dimension(self, shape: tuple[int, int]) -> int:
"""
Determine which index is the largest dimension.
Expand Down Expand Up @@ -401,6 +413,7 @@ def _initialize_image_viewport_stretch_cuts(
def _load_fits(self, file: str | os.PathLike, image_label: str | None) -> None:
ccd = CCDData.read(file)
self._images[image_label].wcs = ccd.wcs
self._images[image_label].data = ccd
self._initialize_image_viewport_stretch_cuts(ccd.data, image_label)

def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
Expand All @@ -416,6 +429,7 @@ def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
self._images[image_label].largest_dimension = self._determine_largest_dimension(
array.shape
)
self._images[image_label].data = array
self._initialize_image_viewport_stretch_cuts(array, image_label)

def _load_nddata(self, data: NDData, image_label: str | None) -> None:
Expand All @@ -428,6 +442,7 @@ def _load_nddata(self, data: NDData, image_label: str | None) -> None:
The NDData object to load.
"""
self._images[image_label].wcs = data.wcs
self._images[image_label].data = data
self._images[image_label].largest_dimension = self._determine_largest_dimension(
data.data.shape
)
Expand Down
44 changes: 44 additions & 0 deletions src/astro_image_display_api/interface_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,50 @@ def load_image(self, data: Any, image_label: str | None = None) -> None:
raise NotImplementedError

# Setting and getting image properties
@abstractmethod
def get_image(
self,
image_label: str | None = None,
) -> Any:
"""
Parameters
----------
image_label : optional
The label of the image to set the cuts for. If not given and there is
only one image loaded, that image is returned.

Returns
-------
image_data : Any
The data of the loaded image. The exact type of the data is not specified,
and different backends may return different types. A return type compatible
with `astropy.nddata.NDData` is preferred, but not required. It is expected
that the returned data can be re-loaded into the viewer using
`load_image`, however.

Raises
------
ValueError
If the `image_label` is not provided when there are multiple images loaded,
or if the `image_label` does not correspond to a loaded image.

"""
raise NotImplementedError

@abstractmethod
def get_image_labels(
self,
) -> tuple[str]:
"""
Get the labels of the loaded images.

Returns
-------
image_labels: tuple of str
The labels of the loaded images.
"""
raise NotImplementedError

@abstractmethod
def set_cuts(
self,
Expand Down