Skip to content

Commit 3b33231

Browse files
final unit tests for hdf5
1 parent 9f23284 commit 3b33231

File tree

3 files changed

+77
-9
lines changed

3 files changed

+77
-9
lines changed

src/smokescreen/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def load_sacc_file(path_to_sacc: str) -> tuple[sacc.Sacc, str]:
157157
Returns
158158
-------
159159
sacc.Sacc
160-
Loaded SACC object
160+
Loaded SACC object with _smokescreen_input_format attribute set
161161
str
162162
Detected input format ('fits' or 'hdf5')
163163
@@ -169,13 +169,15 @@ def load_sacc_file(path_to_sacc: str) -> tuple[sacc.Sacc, str]:
169169
# Try HDF5 first (more specific format check)
170170
try:
171171
sacc_obj = sacc.Sacc.load_hdf5(path_to_sacc)
172+
sacc_obj._smokescreen_input_format = 'hdf5'
172173
return sacc_obj, 'hdf5'
173174
except Exception:
174175
pass
175176

176177
# Fall back to FITS format
177178
try:
178179
sacc_obj = sacc.Sacc.load_fits(path_to_sacc)
180+
sacc_obj._smokescreen_input_format = 'fits'
179181
return sacc_obj, 'fits'
180182
except Exception as e:
181183
raise ValueError(

tests/test_encryption.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import pytest
3+
from cryptography.fernet import Fernet
34
from smokescreen.encryption import encrypt_file, decrypt_file
45

56

@@ -120,3 +121,64 @@ def test_decrypt_file_nonexistent_key(encrypted_file_and_key):
120121
# Test decrypting with a nonexistent key
121122
with pytest.raises(FileNotFoundError):
122123
decrypt_file(str(encrypted_file_path), "nonexistent_key.key")
124+
125+
126+
def test_decrypt_file_save_with_fallback_handling(tmp_path):
127+
# Test the fallback handling in decrypt_file when filename doesn't end with .encrpt
128+
# Note: encrypt_file stores as basename.encrpt where basename is split('.')[0]
129+
# So data.sacc becomes data.encrpt, and decrypts back to data (original extension lost)
130+
original_content = b"test content for decryption fallback"
131+
original_file = tmp_path / "data.sacc"
132+
original_file.write_bytes(original_content)
133+
134+
# Encrypt it - basename is extracted as 'data' from 'data.sacc'
135+
encrypted_sacc, key = encrypt_file(str(original_file), path_to_save=str(tmp_path),
136+
save_file=True, keep_original=False)
137+
encrypted_file = tmp_path / "data.encrpt"
138+
key_file = tmp_path / "data.key"
139+
140+
assert encrypted_file.exists()
141+
assert key_file.exists()
142+
143+
# Decrypt with save_file=True
144+
decrypt_file(str(encrypted_file), str(key_file), save_file=True)
145+
146+
# The decrypted file restores basename (without .encrpt) which is 'data'
147+
# Note: original extension '.sacc' cannot be recovered after encryption
148+
decrypted_file = tmp_path / "data"
149+
assert decrypted_file.exists()
150+
151+
# Verify content matches original
152+
assert decrypted_file.read_bytes() == original_content
153+
154+
155+
def test_decrypt_file_save_with_fallback_extension_removal(tmp_path):
156+
# Test fallback handling when .encrpt appears in middle of filename
157+
# Note: encrypt_file stores as basename.encrpt where basename is split('.')[0]
158+
# So data.backup becomes data.encrpt, and decrypts back to data (original extension lost)
159+
160+
original_content = b"test content for edge case decryption"
161+
backup_content = b"backup content that was encrypted"
162+
163+
# Encrypt a backup file
164+
backup_file = tmp_path / "data.backup"
165+
backup_file.write_bytes(backup_content)
166+
167+
encrypted_sacc, key = encrypt_file(str(backup_file), path_to_save=str(tmp_path),
168+
save_file=True, keep_original=False)
169+
# The encrypted file is data.encrpt (basename extracted from 'data.backup')
170+
assert (tmp_path / "data.encrpt").exists()
171+
172+
# Get the key file path
173+
key_file = tmp_path / "data.key"
174+
175+
# Decrypt it using the key file
176+
decrypted_sacc = decrypt_file(str(tmp_path / "data.encrpt"), str(key_file), save_file=True)
177+
178+
# Should restore to basename (without .encrpt) which is 'data'
179+
# Note: original extension '.backup' cannot be recovered after encryption
180+
decrypted_file = tmp_path / "data"
181+
assert decrypted_file.exists()
182+
assert decrypted_file.read_bytes() == backup_content
183+
184+

tests/test_utils.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ class TestLoadSaccFile:
7777
"""Tests for the load_sacc_file utility function."""
7878

7979
def test_load_sacc_file_fits_format(self, tmp_path):
80+
import sacc as sacc_mod
8081
# Create a FITS SACC file
81-
sacc_data = sacc.Sacc()
82+
sacc_data = sacc_mod.Sacc()
8283
sacc_data.add_tracer('misc', 'test')
8384
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
8485
fits_path = tmp_path / "test.fits"
@@ -87,13 +88,14 @@ def test_load_sacc_file_fits_format(self, tmp_path):
8788
# Load using load_sacc_file
8889
loaded_sacc, file_format = load_sacc_file(str(fits_path))
8990

90-
assert isinstance(loaded_sacc, sacc.Sacc)
91+
assert isinstance(loaded_sacc, sacc_mod.Sacc)
9192
assert file_format == 'fits'
9293
assert len(loaded_sacc.mean) == 1
9394

9495
def test_load_sacc_file_hdf5_format(self, tmp_path):
96+
import sacc as sacc_mod
9597
# Create an HDF5 SACC file
96-
sacc_data = sacc.Sacc()
98+
sacc_data = sacc_mod.Sacc()
9799
sacc_data.add_tracer('misc', 'test')
98100
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
99101
hdf5_path = tmp_path / "test.hdf5"
@@ -102,13 +104,14 @@ def test_load_sacc_file_hdf5_format(self, tmp_path):
102104
# Load using load_sacc_file
103105
loaded_sacc, file_format = load_sacc_file(str(hdf5_path))
104106

105-
assert isinstance(loaded_sacc, sacc.Sacc)
107+
assert isinstance(loaded_sacc, sacc_mod.Sacc)
106108
assert file_format == 'hdf5'
107109
assert len(loaded_sacc.mean) == 1
108110

109111
def test_load_sacc_file_with_h5_extension(self, tmp_path):
112+
import sacc as sacc_mod
110113
# Create an HDF5 SACC file with .h5 extension
111-
sacc_data = sacc.Sacc()
114+
sacc_data = sacc_mod.Sacc()
112115
sacc_data.add_tracer('misc', 'test')
113116
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
114117
h5_path = tmp_path / "test.h5"
@@ -117,12 +120,13 @@ def test_load_sacc_file_with_h5_extension(self, tmp_path):
117120
# Load using load_sacc_file - should detect as HDF5 regardless of extension
118121
loaded_sacc, file_format = load_sacc_file(str(h5_path))
119122

120-
assert isinstance(loaded_sacc, sacc.Sacc)
123+
assert isinstance(loaded_sacc, sacc_mod.Sacc)
121124
assert file_format == 'hdf5'
122125

123126
def test_load_sacc_file_with_sacc_extension_hdf5(self, tmp_path):
127+
import sacc as sacc_mod
124128
# Create an HDF5 SACC file with .sacc extension (like sn_datavector.sacc)
125-
sacc_data = sacc.Sacc()
129+
sacc_data = sacc_mod.Sacc()
126130
sacc_data.add_tracer('misc', 'test')
127131
sacc_data.add_data_point('galaxy_shear_cl_ee', ('test', 'test'), 1.0, ell=10)
128132
sacc_path = tmp_path / "test.sacc"
@@ -131,7 +135,7 @@ def test_load_sacc_file_with_sacc_extension_hdf5(self, tmp_path):
131135
# Load using load_sacc_file - should detect as HDF5 even with .sacc extension
132136
loaded_sacc, file_format = load_sacc_file(str(sacc_path))
133137

134-
assert isinstance(loaded_sacc, sacc.Sacc)
138+
assert isinstance(loaded_sacc, sacc_mod.Sacc)
135139
assert file_format == 'hdf5'
136140

137141
def test_load_sacc_file_nonexistent(self):

0 commit comments

Comments
 (0)