@@ -70,7 +70,7 @@ class variable does the trick.
7070 def _assert_empty_catalog_table (self , table ):
7171 assert isinstance (table , Table )
7272 assert len (table ) == 0
73- assert sorted (table .colnames ) == sorted (['x' , 'y' , 'coord' , 'marker name' ])
73+ assert sorted (table .colnames ) == sorted (['x' , 'y' , 'coord' ])
7474
7575 def _get_catalog_names_as_set (self ):
7676 marks = self .image .get_catalog_names ()
@@ -190,84 +190,98 @@ def test_get_catalog_style_with_multiple_labels_raises_error(self, catalog):
190190 with pytest .raises (ValueError , match = 'Multiple catalog styles' ):
191191 self .image .get_catalog_style ()
192192
193- def test_load_catalog (self ):
194- data = np .arange (10 ).reshape (5 , 2 )
195- orig_tab = Table (data = data , names = ['x' , 'y' ], dtype = ('float' , 'float' ))
196- tab = Table (data = data , names = ['x' , 'y' ], dtype = ('float' , 'float' ))
193+ @pytest .mark .parametrize ("catalog_label" , ['test1' , None ])
194+ def test_load_get_single_catalog_with_without_label (self , catalog , catalog_label ):
195+ # Make sure we can get a single catalog with or without a label.
197196 self .image .load_catalog (
198- tab ,
197+ catalog ,
199198 x_colname = 'x' ,
200199 y_colname = 'y' ,
201200 skycoord_colname = 'coord' ,
202- catalog_label = 'test1' ,
201+ catalog_label = catalog_label ,
203202 use_skycoord = False
204203 )
205204
205+ # Get the catalog without a label
206+ retrieved_catalog = self .image .get_catalog ()
207+ assert (retrieved_catalog == catalog ).all ()
208+
209+ # Get the catalog with a label if there is one
210+ if catalog_label is not None :
211+ retrieved_catalog = self .image .get_catalog (catalog_label = catalog_label )
212+ assert (retrieved_catalog == catalog ).all ()
206213
207- # Regression test for GitHub Issue 45:
208- # Adding markers should not modify the input data table.
209- assert (tab == orig_tab ).all ()
214+ def test_load_catalog_does_not_modify_input_catalog (self , catalog , data ):
215+ # Adding a catalog should not modify the input data table.
216+ orig_tab = catalog .copy ()
217+ self .image .load_catalog (catalog )
218+ _ = self .image .get_catalog ()
219+ assert (catalog == orig_tab ).all ()
210220
211- # Add more markers under different name.
221+ def test_load_multiple_catalogs (self , catalog ):
222+ # Load and get mulitple catalogs
223+ # Add a catalog
212224 self .image .load_catalog (
213- tab ,
225+ catalog ,
226+ x_colname = 'x' ,
227+ y_colname = 'y' ,
228+ catalog_label = 'test1' ,
229+ )
230+ # Add the catalog again under different name.
231+ self .image .load_catalog (
232+ catalog ,
214233 x_colname = 'x' ,
215234 y_colname = 'y' ,
216- skycoord_colname = 'coord' ,
217235 catalog_label = 'test2' ,
218- use_skycoord = False
219236 )
220237
221- marknames = self ._get_catalog_names_as_set ()
222- assert marknames == set (['test1' , 'test2' ])
238+ assert sorted (self .image .get_catalog_names ()) == ['test1' , 'test2' ]
223239
224240 # No guarantee markers will come back in the same order, so sort them.
225241 t1 = self .image .get_catalog (catalog_label = 'test1' )
226242 # Sort before comparing
227- t1 .sort ('x' )
228- tab .sort ('x' )
229- assert np . all (t1 ['x' ] == tab ['x' ])
230- assert (t1 ['y' ] == tab ['y' ]).all ()
243+ t1 .sort ([ 'x' , 'y' ] )
244+ catalog .sort ([ 'x' , 'y' ] )
245+ assert (t1 ['x' ] == catalog ['x' ]). all ( )
246+ assert (t1 ['y' ] == catalog ['y' ]).all ()
231247
232248 t2 = self .image .get_catalog (catalog_label = "test2" )
233249 # Sort before comparing
234250 t2 .sort (['x' , 'y' ])
235- tab .sort (['x' , 'y' ])
236- assert (t2 ['x' ] == tab ['x' ]).all ()
237- assert (t2 ['y' ] == tab ['y' ]).all ()
251+ assert (t2 ['x' ] == catalog ['x' ]).all ()
252+ assert (t2 ['y' ] == catalog ['y' ]).all ()
238253
239- self . image . remove_catalog ( catalog_label = 'test1' )
240- marknames = self . _get_catalog_names_as_set ()
241- assert marknames == set ([ 'test2' ] )
254+ # get_catalog without a label should fail with multiple catalogs
255+ with pytest . raises ( ValueError , match = "Multiple catalog styles defined." ):
256+ self . image . get_catalog ( )
242257
243- # Add markers with no marker name and check we can retrieve them
244- # using the default marker name
245- self .image .load_catalog (
246- tab ,
247- x_colname = 'x' ,
248- y_colname = 'y' ,
249- skycoord_colname = 'coord' ,
250- use_skycoord = False
251- )
252- # Don't care about the order of the marker names so use set instead of
253- # list.
254- marknames = self ._get_catalog_names_as_set ()
255- assert (set (marknames ) == set (['test2' ]))
258+ # if we remove one of the catalogs we should be able to get the
259+ # other one without a label.
260+ self .image .remove_catalog (catalog_label = 'test1' )
261+ # Make sure test1 is really gone.
262+ assert self .image .get_catalog_names () == ['test2' ]
256263
257- # Clear markers to not pollute other tests.
258- self .image .remove_catalog (catalog_label = '*' )
259- marknames = self ._get_catalog_names_as_set ()
260- assert len (marknames ) == 0
261- self ._assert_empty_catalog_table (self .image .get_catalog ())
262- # Check that no markers remain after clearing
263- tab = self .image .get_catalog ()
264- self ._assert_empty_catalog_table (tab )
264+ # Get without a catalog
265+ t2 = self .image .get_catalog ()
266+ # Sort before comparing
267+ t2 .sort (['x' , 'y' ])
268+ assert (t2 ['x' ] == catalog ['x' ]).all ()
269+ assert (t2 ['y' ] == catalog ['y' ]).all ()
265270
266271 # Check that retrieving a marker set that doesn't exist returns
267272 # an empty table with the right columns
268273 tab = self .image .get_catalog (catalog_label = 'test1' )
269274 self ._assert_empty_catalog_table (tab )
270275
276+ def test_load_catalog_multiple_same_label (self , catalog ):
277+ # Check that loading a catalog with the same label multiple times
278+ # does not raise an error and does not change the catalog.
279+ self .image .load_catalog (catalog , catalog_label = 'test1' )
280+ self .image .load_catalog (catalog , catalog_label = 'test1' )
281+
282+ retrieved_catalog = self .image .get_catalog (catalog_label = 'test1' )
283+ assert len (retrieved_catalog ) == 2 * len (catalog )
284+
271285 def test_load_catalog_with_skycoord_no_wcs (self , catalog , data ):
272286 # Check that loading a catalog with skycoord but no x/y and
273287 # no WCS returns a catlog with None for x and y.
0 commit comments