Skip to content

Commit 2da04b4

Browse files
committed
Break load test into several tests
1 parent d83d23c commit 2da04b4

File tree

1 file changed

+61
-47
lines changed

1 file changed

+61
-47
lines changed

src/astro_image_display_api/widget_api_test.py

Lines changed: 61 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class variable does the trick.
7070
def _assert_empty_catalog_table(self, table):
7171
assert isinstance(table, Table)
7272
assert len(table) == 0
73-
assert sorted(table.colnames) == sorted(['x', 'y', 'coord', 'marker name'])
73+
assert sorted(table.colnames) == sorted(['x', 'y', 'coord'])
7474

7575
def _get_catalog_names_as_set(self):
7676
marks = self.image.get_catalog_names()
@@ -190,84 +190,98 @@ def test_get_catalog_style_with_multiple_labels_raises_error(self, catalog):
190190
with pytest.raises(ValueError, match='Multiple catalog styles'):
191191
self.image.get_catalog_style()
192192

193-
def test_load_catalog(self):
194-
data = np.arange(10).reshape(5, 2)
195-
orig_tab = Table(data=data, names=['x', 'y'], dtype=('float', 'float'))
196-
tab = Table(data=data, names=['x', 'y'], dtype=('float', 'float'))
193+
@pytest.mark.parametrize("catalog_label", ['test1', None])
194+
def test_load_get_single_catalog_with_without_label(self, catalog, catalog_label):
195+
# Make sure we can get a single catalog with or without a label.
197196
self.image.load_catalog(
198-
tab,
197+
catalog,
199198
x_colname='x',
200199
y_colname='y',
201200
skycoord_colname='coord',
202-
catalog_label='test1',
201+
catalog_label=catalog_label,
203202
use_skycoord=False
204203
)
205204

205+
# Get the catalog without a label
206+
retrieved_catalog = self.image.get_catalog()
207+
assert (retrieved_catalog == catalog).all()
208+
209+
# Get the catalog with a label if there is one
210+
if catalog_label is not None:
211+
retrieved_catalog = self.image.get_catalog(catalog_label=catalog_label)
212+
assert (retrieved_catalog == catalog).all()
206213

207-
# Regression test for GitHub Issue 45:
208-
# Adding markers should not modify the input data table.
209-
assert (tab == orig_tab).all()
214+
def test_load_catalog_does_not_modify_input_catalog(self, catalog, data):
215+
# Adding a catalog should not modify the input data table.
216+
orig_tab = catalog.copy()
217+
self.image.load_catalog(catalog)
218+
_ = self.image.get_catalog()
219+
assert (catalog == orig_tab).all()
210220

211-
# Add more markers under different name.
221+
def test_load_multiple_catalogs(self, catalog):
222+
# Load and get mulitple catalogs
223+
# Add a catalog
212224
self.image.load_catalog(
213-
tab,
225+
catalog,
226+
x_colname='x',
227+
y_colname='y',
228+
catalog_label='test1',
229+
)
230+
# Add the catalog again under different name.
231+
self.image.load_catalog(
232+
catalog,
214233
x_colname='x',
215234
y_colname='y',
216-
skycoord_colname='coord',
217235
catalog_label='test2',
218-
use_skycoord=False
219236
)
220237

221-
marknames = self._get_catalog_names_as_set()
222-
assert marknames == set(['test1', 'test2'])
238+
assert sorted(self.image.get_catalog_names()) == ['test1', 'test2']
223239

224240
# No guarantee markers will come back in the same order, so sort them.
225241
t1 = self.image.get_catalog(catalog_label='test1')
226242
# Sort before comparing
227-
t1.sort('x')
228-
tab.sort('x')
229-
assert np.all(t1['x'] == tab['x'])
230-
assert (t1['y'] == tab['y']).all()
243+
t1.sort(['x', 'y'])
244+
catalog.sort(['x', 'y'])
245+
assert (t1['x'] == catalog['x']).all()
246+
assert (t1['y'] == catalog['y']).all()
231247

232248
t2 = self.image.get_catalog(catalog_label="test2")
233249
# Sort before comparing
234250
t2.sort(['x', 'y'])
235-
tab.sort(['x', 'y'])
236-
assert (t2['x'] == tab['x']).all()
237-
assert (t2['y'] == tab['y']).all()
251+
assert (t2['x'] == catalog['x']).all()
252+
assert (t2['y'] == catalog['y']).all()
238253

239-
self.image.remove_catalog(catalog_label='test1')
240-
marknames = self._get_catalog_names_as_set()
241-
assert marknames == set(['test2'])
254+
# get_catalog without a label should fail with multiple catalogs
255+
with pytest.raises(ValueError, match="Multiple catalog styles defined."):
256+
self.image.get_catalog()
242257

243-
# Add markers with no marker name and check we can retrieve them
244-
# using the default marker name
245-
self.image.load_catalog(
246-
tab,
247-
x_colname='x',
248-
y_colname='y',
249-
skycoord_colname='coord',
250-
use_skycoord=False
251-
)
252-
# Don't care about the order of the marker names so use set instead of
253-
# list.
254-
marknames = self._get_catalog_names_as_set()
255-
assert (set(marknames) == set(['test2']))
258+
# if we remove one of the catalogs we should be able to get the
259+
# other one without a label.
260+
self.image.remove_catalog(catalog_label='test1')
261+
# Make sure test1 is really gone.
262+
assert self.image.get_catalog_names() == ['test2']
256263

257-
# Clear markers to not pollute other tests.
258-
self.image.remove_catalog(catalog_label='*')
259-
marknames = self._get_catalog_names_as_set()
260-
assert len(marknames) == 0
261-
self._assert_empty_catalog_table(self.image.get_catalog())
262-
# Check that no markers remain after clearing
263-
tab = self.image.get_catalog()
264-
self._assert_empty_catalog_table(tab)
264+
# Get without a catalog
265+
t2 = self.image.get_catalog()
266+
# Sort before comparing
267+
t2.sort(['x', 'y'])
268+
assert (t2['x'] == catalog['x']).all()
269+
assert (t2['y'] == catalog['y']).all()
265270

266271
# Check that retrieving a marker set that doesn't exist returns
267272
# an empty table with the right columns
268273
tab = self.image.get_catalog(catalog_label='test1')
269274
self._assert_empty_catalog_table(tab)
270275

276+
def test_load_catalog_multiple_same_label(self, catalog):
277+
# Check that loading a catalog with the same label multiple times
278+
# does not raise an error and does not change the catalog.
279+
self.image.load_catalog(catalog, catalog_label='test1')
280+
self.image.load_catalog(catalog, catalog_label='test1')
281+
282+
retrieved_catalog = self.image.get_catalog(catalog_label='test1')
283+
assert len(retrieved_catalog) == 2 * len(catalog)
284+
271285
def test_load_catalog_with_skycoord_no_wcs(self, catalog, data):
272286
# Check that loading a catalog with skycoord but no x/y and
273287
# no WCS returns a catlog with None for x and y.

0 commit comments

Comments
 (0)