From 558befb6258f61e6ae1af25ed0e28a424be7e7cf Mon Sep 17 00:00:00 2001 From: Matt Craig Date: Mon, 30 Jun 2025 10:47:21 -0700 Subject: [PATCH] Add test for and fix bug in viewer logic The bug was that the second time loda_cagatalog is called without a style argument the style gets set to None instead of the default style. --- src/astro_image_display_api/api_test.py | 15 +++++++++++++++ src/astro_image_display_api/image_viewer_logic.py | 5 +++++ 2 files changed, 20 insertions(+) diff --git a/src/astro_image_display_api/api_test.py b/src/astro_image_display_api/api_test.py index 589a0c2..63d3c1c 100644 --- a/src/astro_image_display_api/api_test.py +++ b/src/astro_image_display_api/api_test.py @@ -438,6 +438,21 @@ def test_set_get_catalog_style_preserves_extra_keywords(self, catalog): del retrieved_style["catalog_label"] # Remove the label assert retrieved_style == style + def test_catalog_has_style_after_loading(self, catalog): + # Check that loading a catalog sets a default style for that catalog. + self.image.load_catalog(catalog, catalog_label="test1") + + retrieved_style = self.image.get_catalog_style(catalog_label="test1") + assert isinstance(retrieved_style, dict) + assert "color" in retrieved_style + assert "shape" in retrieved_style + assert "size" in retrieved_style + + # Loading again should have the same style + self.image.load_catalog(catalog, catalog_label="test1") + retrieved_style2 = self.image.get_catalog_style(catalog_label="test1") + assert retrieved_style2 == retrieved_style + @pytest.mark.parametrize("catalog_label", ["test1", None]) def test_load_get_single_catalog_with_without_label(self, catalog, catalog_label): # Make sure we can get a single catalog with or without a label. diff --git a/src/astro_image_display_api/image_viewer_logic.py b/src/astro_image_display_api/image_viewer_logic.py index 4f830e0..af4f390 100644 --- a/src/astro_image_display_api/image_viewer_logic.py +++ b/src/astro_image_display_api/image_viewer_logic.py @@ -26,6 +26,7 @@ __all__ = ["ImageViewerLogic"] + @dataclass class CatalogInfo: """ @@ -523,7 +524,11 @@ def load_catalog( # Ensure a catalog always has a style if catalog_style is None: if not self._catalogs[catalog_label].style: + # No style has been set, so use the default style catalog_style = self._default_catalog_style.copy() + else: + # Use the existing style + catalog_style = self._catalogs[catalog_label].style.copy() self._catalogs[catalog_label].style = catalog_style