11import logging
22import unittest .mock as mock
33from pathlib import Path
4+ from typing import Any , Optional
45
56import numpy as np
67import pytest
7- from fibad .data_sets .hsc_data_set import HSCDataSet
88from torchvision .transforms .v2 import CenterCrop , Lambda
99
10+ from fibad .data_sets .hsc_data_set import HSCDataSet
11+
1012test_dir = Path (__file__ ).parent / "test_data" / "dataloader"
1113
1214HSCDataSet ._called_from_test = True
@@ -27,8 +29,8 @@ class FakeFitsFS:
2729 more filesystem operations without a really good reason.
2830 """
2931
30- def __init__ (self , test_files : dict ):
31- self .patchers : list [mock ._patch [mock . Mock ]] = []
32+ def __init__ (self , test_files : dict , filter_catalog : Optional [ dict ] = None ):
33+ self .patchers : list [mock ._patch [Any ]] = []
3234
3335 self .test_files = test_files
3436
@@ -39,6 +41,12 @@ def __init__(self, test_files: dict):
3941 mock_fits_open = mock .Mock (side_effect = self ._open_file )
4042 self .patchers .append (mock .patch ("astropy.io.fits.open" , mock_fits_open ))
4143
44+ if filter_catalog is not None :
45+ mock_read_filter_catalog = mock .patch (
46+ "fibad.data_sets.hsc_data_set.HSCDataSet._read_filter_catalog" , lambda x , y : filter_catalog
47+ )
48+ self .patchers .append (mock_read_filter_catalog )
49+
4250 def _open_file (self , filename : Path , ** kwargs ) -> mock .Mock :
4351 shape = self .test_files [filename .name ]
4452 mock_open_ctx = mock .Mock ()
@@ -117,6 +125,24 @@ def generate_files(
117125 return test_files
118126
119127
def generate_filter_catalog(test_files: dict) -> dict:
    """Build a filter catalog dict suitable for FakeFitsFS from a fake filesystem dictionary.

    The catalog is obtained by letting HSCDataSet initialize itself against the fake
    filesystem and returning its parsed ``files`` attribute. Tests can then modify the
    catalog to see what HSCDataSet decides when a manifest or filter_catalog file
    contains corrupt information.

    Parameters
    ----------
    test_files : dict
        Filesystem dictionary as created by generate_files.

    Returns
    -------
    dict
        Dictionary from ObjectID -> (Dictionary from Filter -> Filename)
    """
    # Reuse the dataset's own initialization code to produce the parsed files object.
    with FakeFitsFS(test_files):
        data_set = HSCDataSet(mkconfig(crop_to=(99, 99)))
        return data_set.files
144+
145+
120146def test_load (caplog ):
121147 """Test to ensure loading a perfectly regular set of files works"""
122148 caplog .set_level (logging .WARNING )
@@ -313,6 +339,51 @@ def test_prune_size(caplog):
313339 assert "too small" in caplog .text
314340
315341
def test_prune_filter_size_mismatch(caplog):
    """Test to ensure objects whose images have different sizes per filter will be dropped"""
    caplog.set_level(logging.WARNING)
    test_files = dict(generate_files(num_objects=10, num_filters=5, shape=(100, 100), offset=0))

    # Give a single filter image of one object a smaller shape than the rest.
    # NOTE(review): assumes generate_files emits this exact filename for the first object.
    test_files["00000000000000000_all_filters_HSC-R.fits"] = (99, 99)

    with FakeFitsFS(test_files):
        a = HSCDataSet(mkconfig(crop_to=(99, 99)))

        # Only the object with mismatched filter sizes should be missing.
        assert len(a) == 9
        assert a.shape() == (5, 99, 99)

        # We should warn that we are dropping objects and the reason
        assert "Dropping object" in caplog.text
        assert "first filter" in caplog.text
359+
360+
def test_prune_bad_filename(caplog):
    """Test to ensure objects whose catalog entries point at the wrong filename will be dropped"""
    caplog.set_level(logging.WARNING)
    test_files = dict(generate_files(num_objects=10, num_filters=5, shape=(100, 100), offset=0))

    # Build a valid filter catalog, then corrupt one entry: the first filter of one
    # object is redirected to the file that belongs to a different object.
    filter_catalog = generate_filter_catalog(test_files)
    first_filter = list(filter_catalog["00000000000000000"])[0]
    donor_filename = filter_catalog["00000000000000001"][first_filter]
    filter_catalog["00000000000000000"][first_filter] = donor_filename

    with FakeFitsFS(test_files, filter_catalog):
        # Passing a filter_catalog config value exercises the catalog-provided
        # initialization pathway of HSCDataSet; the file itself is never read
        # because FakeFitsFS patches the catalog reader.
        config = mkconfig(crop_to=(99, 99), filter_catalog="notarealfile.fits")
        a = HSCDataSet(config)

        # The broken object must be gone.
        assert len(a) == 9

        # The remaining data keeps the expected shape.
        assert a.shape() == (5, 99, 99)

        # A warning must say that an object was dropped, and give the correct reason.
        assert "Dropping object" in caplog.text
        assert "manifest is likely corrupt" in caplog.text
385+
386+
316387def test_partial_filter (caplog ):
317388 """Test to ensure when we only load some of the filters, only those filters end up in the dataset"""
318389 caplog .set_level (logging .WARNING )
0 commit comments