Skip to content

Commit 1eab1c8

Browse files
added tests to utils
1 parent 6c75ee1 commit 1eab1c8

File tree

1 file changed

+80
-0
lines changed

1 file changed

+80
-0
lines changed

tests/test_utils.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
import os
44
import numpy as np
55
import pyccl as ccl
6+
import sacc
67
from smokescreen.utils import string_to_seed, load_module_from_path
78
from smokescreen.utils import load_cosmology_from_partial_dict
9+
from smokescreen.utils import load_sacc_file
810

911

1012
def test_load_module_from_path():
@@ -69,3 +71,81 @@ def test_load_cosmology_from_partial_dict_invalid_type():
6971
cosmo_dict = {"Omega_c": "invalid", "sigma8": 0.8}
7072
with pytest.raises(TypeError):
7173
load_cosmology_from_partial_dict(cosmo_dict)
74+
75+
76+
class TestLoadSaccFile:
77+
"""Tests for the load_sacc_file utility function."""
78+
79+
def test_load_sacc_file_fits_format(self, tmp_path):
80+
# Create a FITS SACC file
81+
sacc_data = sacc.Sacc()
82+
sacc_data.add_tracer('misc', 'test')
83+
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
84+
fits_path = tmp_path / "test.fits"
85+
sacc_data.save_fits(str(fits_path))
86+
87+
# Load using load_sacc_file
88+
loaded_sacc, file_format = load_sacc_file(str(fits_path))
89+
90+
assert isinstance(loaded_sacc, sacc.Sacc)
91+
assert file_format == 'fits'
92+
assert len(loaded_sacc.mean) == 1
93+
94+
def test_load_sacc_file_hdf5_format(self, tmp_path):
95+
# Create an HDF5 SACC file
96+
sacc_data = sacc.Sacc()
97+
sacc_data.add_tracer('misc', 'test')
98+
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
99+
hdf5_path = tmp_path / "test.hdf5"
100+
sacc_data.save_hdf5(str(hdf5_path))
101+
102+
# Load using load_sacc_file
103+
loaded_sacc, file_format = load_sacc_file(str(hdf5_path))
104+
105+
assert isinstance(loaded_sacc, sacc.Sacc)
106+
assert file_format == 'hdf5'
107+
assert len(loaded_sacc.mean) == 1
108+
109+
def test_load_sacc_file_with_h5_extension(self, tmp_path):
110+
# Create an HDF5 SACC file with .h5 extension
111+
sacc_data = sacc.Sacc()
112+
sacc_data.add_tracer('misc', 'test')
113+
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
114+
h5_path = tmp_path / "test.h5"
115+
sacc_data.save_hdf5(str(h5_path))
116+
117+
# Load using load_sacc_file - should detect as HDF5 regardless of extension
118+
loaded_sacc, file_format = load_sacc_file(str(h5_path))
119+
120+
assert isinstance(loaded_sacc, sacc.Sacc)
121+
assert file_format == 'hdf5'
122+
123+
def test_load_sacc_file_with_sacc_extension_hdf5(self, tmp_path):
124+
# Create an HDF5 SACC file with .sacc extension (like sn_datavector.sacc)
125+
sacc_data = sacc.Sacc()
126+
sacc_data.add_tracer('misc', 'test')
127+
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
128+
sacc_path = tmp_path / "test.sacc"
129+
sacc_data.save_hdf5(str(sacc_path))
130+
131+
# Load using load_sacc_file - should detect as HDF5 even with .sacc extension
132+
loaded_sacc, file_format = load_sacc_file(str(sacc_path))
133+
134+
assert isinstance(loaded_sacc, sacc.Sacc)
135+
assert file_format == 'hdf5'
136+
137+
def test_load_sacc_file_nonexistent(self):
138+
# Test loading a nonexistent file
139+
with pytest.raises(ValueError) as exc_info:
140+
load_sacc_file("nonexistent_file.sacc")
141+
assert "Cannot load SACC file" in str(exc_info.value)
142+
143+
def test_load_sacc_file_invalid_fits(self, tmp_path):
144+
# Create an invalid FITS file (not a SACC file)
145+
invalid_path = tmp_path / "invalid.fits"
146+
with open(invalid_path, "w") as f:
147+
f.write("This is not a valid FITS file")
148+
149+
# Test that it raises ValueError
150+
with pytest.raises(ValueError):
151+
load_sacc_file(str(invalid_path))

0 commit comments

Comments
 (0)