|
3 | 3 | import os |
4 | 4 | import numpy as np |
5 | 5 | import pyccl as ccl |
| 6 | +import sacc |
6 | 7 | from smokescreen.utils import string_to_seed, load_module_from_path |
7 | 8 | from smokescreen.utils import load_cosmology_from_partial_dict |
| 9 | +from smokescreen.utils import load_sacc_file |
8 | 10 |
|
9 | 11 |
|
10 | 12 | def test_load_module_from_path(): |
@@ -69,3 +71,81 @@ def test_load_cosmology_from_partial_dict_invalid_type(): |
69 | 71 | cosmo_dict = {"Omega_c": "invalid", "sigma8": 0.8} |
70 | 72 | with pytest.raises(TypeError): |
71 | 73 | load_cosmology_from_partial_dict(cosmo_dict) |
| 74 | + |
| 75 | + |
| 76 | +class TestLoadSaccFile: |
| 77 | + """Tests for the load_sacc_file utility function.""" |
| 78 | + |
| 79 | + def test_load_sacc_file_fits_format(self, tmp_path): |
| 80 | + # Create a FITS SACC file |
| 81 | + sacc_data = sacc.Sacc() |
| 82 | + sacc_data.add_tracer('misc', 'test') |
| 83 | + sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10) |
| 84 | + fits_path = tmp_path / "test.fits" |
| 85 | + sacc_data.save_fits(str(fits_path)) |
| 86 | + |
| 87 | + # Load using load_sacc_file |
| 88 | + loaded_sacc, file_format = load_sacc_file(str(fits_path)) |
| 89 | + |
| 90 | + assert isinstance(loaded_sacc, sacc.Sacc) |
| 91 | + assert file_format == 'fits' |
| 92 | + assert len(loaded_sacc.mean) == 1 |
| 93 | + |
| 94 | + def test_load_sacc_file_hdf5_format(self, tmp_path): |
| 95 | + # Create an HDF5 SACC file |
| 96 | + sacc_data = sacc.Sacc() |
| 97 | + sacc_data.add_tracer('misc', 'test') |
| 98 | + sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10) |
| 99 | + hdf5_path = tmp_path / "test.hdf5" |
| 100 | + sacc_data.save_hdf5(str(hdf5_path)) |
| 101 | + |
| 102 | + # Load using load_sacc_file |
| 103 | + loaded_sacc, file_format = load_sacc_file(str(hdf5_path)) |
| 104 | + |
| 105 | + assert isinstance(loaded_sacc, sacc.Sacc) |
| 106 | + assert file_format == 'hdf5' |
| 107 | + assert len(loaded_sacc.mean) == 1 |
| 108 | + |
| 109 | + def test_load_sacc_file_with_h5_extension(self, tmp_path): |
| 110 | + # Create an HDF5 SACC file with .h5 extension |
| 111 | + sacc_data = sacc.Sacc() |
| 112 | + sacc_data.add_tracer('misc', 'test') |
| 113 | + sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10) |
| 114 | + h5_path = tmp_path / "test.h5" |
| 115 | + sacc_data.save_hdf5(str(h5_path)) |
| 116 | + |
| 117 | + # Load using load_sacc_file - should detect as HDF5 regardless of extension |
| 118 | + loaded_sacc, file_format = load_sacc_file(str(h5_path)) |
| 119 | + |
| 120 | + assert isinstance(loaded_sacc, sacc.Sacc) |
| 121 | + assert file_format == 'hdf5' |
| 122 | + |
| 123 | + def test_load_sacc_file_with_sacc_extension_hdf5(self, tmp_path): |
| 124 | + # Create an HDF5 SACC file with .sacc extension (like sn_datavector.sacc) |
| 125 | + sacc_data = sacc.Sacc() |
| 126 | + sacc_data.add_tracer('misc', 'test') |
| 127 | + sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10) |
| 128 | + sacc_path = tmp_path / "test.sacc" |
| 129 | + sacc_data.save_hdf5(str(sacc_path)) |
| 130 | + |
| 131 | + # Load using load_sacc_file - should detect as HDF5 even with .sacc extension |
| 132 | + loaded_sacc, file_format = load_sacc_file(str(sacc_path)) |
| 133 | + |
| 134 | + assert isinstance(loaded_sacc, sacc.Sacc) |
| 135 | + assert file_format == 'hdf5' |
| 136 | + |
| 137 | + def test_load_sacc_file_nonexistent(self): |
| 138 | + # Test loading a nonexistent file |
| 139 | + with pytest.raises(ValueError) as exc_info: |
| 140 | + load_sacc_file("nonexistent_file.sacc") |
| 141 | + assert "Cannot load SACC file" in str(exc_info.value) |
| 142 | + |
| 143 | + def test_load_sacc_file_invalid_fits(self, tmp_path): |
| 144 | + # Create an invalid FITS file (not a SACC file) |
| 145 | + invalid_path = tmp_path / "invalid.fits" |
| 146 | + with open(invalid_path, "w") as f: |
| 147 | + f.write("This is not a valid FITS file") |
| 148 | + |
| 149 | + # Test that it raises ValueError |
| 150 | + with pytest.raises(ValueError): |
| 151 | + load_sacc_file(str(invalid_path)) |
0 commit comments