Skip to content

Commit 826da47

Browse files
authored
Merge pull request #71 from eteq/get-image
Add get_image and get_image_label
2 parents 3ec9078 + b29ed57 commit 826da47

File tree

3 files changed

+105
-14
lines changed

3 files changed

+105
-14
lines changed

src/astro_image_display_api/api_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -797,3 +797,28 @@ def test_save_overwrite(self, tmp_path):
797797

798798
# Using overwrite should save successfully
799799
self.image.save(filename, overwrite=True)
800+
801+
def test_get_image_labels(self, data):
802+
# the test viewer begins with a default empty image
803+
assert len(self.image.get_image_labels()) == 1
804+
assert self.image.get_image_labels()[0] is None
805+
assert isinstance(self.image.get_image_labels(), tuple)
806+
807+
self.image.load_image(data, image_label="test")
808+
assert len(self.image.get_image_labels()) == 2
809+
assert self.image.get_image_labels()[-1] == "test"
810+
811+
def test_get_image(self, data):
812+
self.image.load_image(data, image_label="test")
813+
814+
# currently the type is not specified in the API
815+
assert self.image.get_image() is not None
816+
assert self.image.get_image(image_label="test") is not None
817+
818+
retrieved_image = self.image.get_image(image_label="test")
819+
820+
self.image.load_image(retrieved_image, image_label="another test")
821+
assert self.image.get_image(image_label="another test") is not None
822+
823+
with pytest.raises(ValueError, match="[Ii]mage label.*not found"):
824+
self.image.get_image(image_label="not a valid label")

src/astro_image_display_api/image_viewer_logic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ class ViewportInfo:
5050
stretch: BaseStretch | None = None
5151
cuts: BaseInterval | tuple[numbers.Real, numbers.Real] | None = None
5252
colormap: str | None = None
53+
data: ArrayLike | NDData | CCDData | None = None
5354

5455

5556
@dataclass
@@ -332,6 +333,17 @@ def load_image(
332333
# working with the new image.
333334
self._wcs = self._images[image_label].wcs
334335

336+
def get_image(self, image_label: str | None = None):
337+
image_label = self._resolve_image_label(image_label)
338+
if image_label not in self._images:
339+
raise ValueError(
340+
f"Image label '{image_label}' not found. Please load an image first."
341+
)
342+
return self._images[image_label].data
343+
344+
def get_image_labels(self):
345+
return tuple(self._images.keys())
346+
335347
def _determine_largest_dimension(self, shape: tuple[int, int]) -> int:
336348
"""
337349
Determine which index is the largest dimension.
@@ -401,6 +413,7 @@ def _initialize_image_viewport_stretch_cuts(
401413
def _load_fits(self, file: str | os.PathLike, image_label: str | None) -> None:
402414
ccd = CCDData.read(file)
403415
self._images[image_label].wcs = ccd.wcs
416+
self._images[image_label].data = ccd
404417
self._initialize_image_viewport_stretch_cuts(ccd.data, image_label)
405418

406419
def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
@@ -416,6 +429,7 @@ def _load_array(self, array: ArrayLike, image_label: str | None) -> None:
416429
self._images[image_label].largest_dimension = self._determine_largest_dimension(
417430
array.shape
418431
)
432+
self._images[image_label].data = array
419433
self._initialize_image_viewport_stretch_cuts(array, image_label)
420434

421435
def _load_nddata(self, data: NDData, image_label: str | None) -> None:
@@ -428,6 +442,7 @@ def _load_nddata(self, data: NDData, image_label: str | None) -> None:
428442
The NDData object to load.
429443
"""
430444
self._images[image_label].wcs = data.wcs
445+
self._images[image_label].data = data
431446
self._images[image_label].largest_dimension = self._determine_largest_dimension(
432447
data.data.shape
433448
)

src/astro_image_display_api/interface_definition.py

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,50 @@ def load_image(self, data: Any, image_label: str | None = None) -> None:
4848
raise NotImplementedError
4949

5050
# Setting and getting image properties
51+
@abstractmethod
52+
def get_image(
53+
self,
54+
image_label: str | None = None,
55+
) -> Any:
56+
"""
57+
Parameters
58+
----------
59+
image_label : optional
60+
The label of the image to set the cuts for. If not given and there is
61+
only one image loaded, that image is returned.
62+
63+
Returns
64+
-------
65+
image_data : Any
66+
The data of the loaded image. The exact type of the data is not specified,
67+
and different backends may return different types. A return type compatible
68+
with `astropy.nddata.NDData` is preferred, but not required. It is expected
69+
that the returned data can be re-loaded into the viewer using
70+
`load_image`, however.
71+
72+
Raises
73+
------
74+
ValueError
75+
If the ``image_label`` is not provided when there are multiple images
76+
loaded, or if the ``image_label`` does not correspond to a loaded image.
77+
78+
"""
79+
raise NotImplementedError
80+
81+
@abstractmethod
82+
def get_image_labels(
83+
self,
84+
) -> tuple[str]:
85+
"""
86+
Get the labels of the loaded images.
87+
88+
Returns
89+
-------
90+
image_labels: tuple of str
91+
The labels of the loaded images.
92+
"""
93+
raise NotImplementedError
94+
5195
@abstractmethod
5296
def set_cuts(
5397
self,
@@ -75,8 +119,8 @@ def set_cuts(
75119
`astropy.visualization.BaseInterval` object.
76120
77121
ValueError
78-
If the ``image_label`` is not provided when there are multiple images loaded,
79-
or if the ``image_label`` does not correspond to a loaded image.
122+
If the ``image_label`` is not provided when there are multiple images
123+
loaded, or if the ``image_label`` does not correspond to a loaded image.
80124
81125
Notes
82126
-----
@@ -105,8 +149,8 @@ def get_cuts(self, image_label: str | None = None) -> BaseInterval:
105149
Raises
106150
------
107151
ValueError
108-
If the ``image_label`` is not provided when there are multiple images loaded,
109-
or if the ``image_label`` does not correspond to a loaded image.
152+
If the ``image_label`` is not provided when there are multiple images
153+
loaded, or if the ``image_label`` does not correspond to a loaded image.
110154
111155
Notes
112156
-----
@@ -132,7 +176,8 @@ def set_stretch(self, stretch: BaseStretch, image_label: str | None = None) -> N
132176
Raises
133177
------
134178
TypeError
135-
If the ``stretch`` is not a valid `~astropy.visualization.BaseStretch` object.
179+
If the ``stretch`` is not a valid `~astropy.visualization.BaseStretch`
180+
object.
136181
137182
ValueError
138183
If the ``image_label`` is not provided when there are multiple images loaded
@@ -291,7 +336,8 @@ def load_catalog(
291336
name will be generated.
292337
catalog_style : dict, optional
293338
A dictionary that specifies the style of the markers used to
294-
represent the catalog. See `~astro_image_display_api.interface_definition.ImageViewerInterface.set_catalog_style`
339+
represent the catalog. See
340+
`~astro_image_display_api.interface_definition.ImageViewerInterface.set_catalog_style`
295341
for details.
296342
297343
Raises
@@ -497,15 +543,17 @@ def set_viewport(
497543
Raises
498544
------
499545
TypeError
500-
If the ``center`` is not a `~astropy.coordinates.SkyCoord` object or a tuple of floats, or if
501-
the ``fov`` is not a angular `~astropy.units.Quantity` or a float, or if there is no WCS
502-
and the center or field of view require a WCS to be applied.
546+
If the ``center`` is not a `~astropy.coordinates.SkyCoord` object or a tuple
547+
of floats, or if the ``fov`` is not a angular `~astropy.units.Quantity` or a
548+
float, or if there is no WCS and the center or field of view require a WCS
549+
to be applied.
503550
504551
ValueError
505552
If ``image_label`` is not provided when there are multiple images loaded.
506553
507554
`astropy.units.UnitTypeError`
508-
If the ``fov`` is a `~astropy.units.Quantity` but does not have an angular unit.
555+
If the ``fov`` is a `~astropy.units.Quantity` but does not have an angular
556+
unit.
509557
510558
Notes
511559
-----
@@ -524,9 +572,11 @@ def get_viewport(
524572
Parameters
525573
----------
526574
sky_or_pixel : str, optional
527-
If 'sky', the center will be returned as a `~astropy.coordinates.SkyCoord` object.
528-
If 'pixel', the center will be returned as a tuple of pixel coordinates.
529-
If `None`, the default behavior is to return the center as a `~astropy.coordinates.SkyCoord` if
575+
If 'sky', the center will be returned as a `~astropy.coordinates.SkyCoord`
576+
object. If 'pixel', the center will be returned as a tuple of pixel
577+
coordinates.
578+
If `None`, the default behavior is to return the center as a
579+
`~astropy.coordinates.SkyCoord` if
530580
possible, or as a tuple of floats if the image is in pixel coordinates and
531581
has no WCS information.
532582
image_label : str, optional
@@ -539,7 +589,8 @@ def get_viewport(
539589
dict
540590
A dictionary containing the current viewport settings.
541591
The keys are 'center', 'fov', and 'image_label'.
542-
- 'center' is an `~astropy.coordinates.SkyCoord` object or a tuple of floats.
592+
- 'center' is an `~astropy.coordinates.SkyCoord` object or a tuple of
593+
floats.
543594
- 'fov' is an `~astropy.units.Quantity` object or a float.
544595
- 'image_label' is a string representing the label of the image.
545596

0 commit comments

Comments
 (0)