Skip to content

Commit 11b750a

Browse files
committed
zenodo stripe removal test with synthetic data
1 parent c7a985f commit 11b750a

File tree

3 files changed

+53
-3
lines changed

3 files changed

+53
-3
lines changed

.scripts/download_zenodo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ def calculate_md5(filename):
1818

1919
def download_zenodo_files(output_dir: Path):
2020
"""
21-
Download all files from Zenodo record 14938787 and verify their checksums.
21+
Download all files from Zenodo record 14979785 and verify their checksums.
2222
2323
Args:
2424
output_dir: Directory where files should be downloaded
2525
"""
2626
try:
27-
print("Fetching files from Zenodo record 14938787...")
27+
print("Fetching files from Zenodo record 14979785...")
2828
with urllib.request.urlopen(
29-
"https://zenodo.org/api/records/14938787"
29+
"https://zenodo.org/api/records/14979785"
3030
) as response:
3131
data = json.loads(response.read())
3232

zenodo-tests/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,20 @@ def geant4_dataset1(geant4_dataset1_file):
185185
)
186186

187187

188+
@pytest.fixture(scope="session")
189+
def synth_tomophantom1_file(test_data_path):
190+
in_file = os.path.join(test_data_path, "synth_tomophantom1.npz")
191+
return np.load(in_file)
192+
193+
194+
@pytest.fixture
195+
def synth_tomophantom1_dataset(synth_tomophantom1_file):
196+
return (
197+
cp.asarray(cp.swapaxes(synth_tomophantom1_file["projdata"], 0, 1)),
198+
synth_tomophantom1_file["angles"],
199+
)
200+
201+
188202
@pytest.fixture
189203
def ensure_clean_memory():
190204
gc.collect()

zenodo-tests/test_prep/test_stripe.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,42 @@ def test_remove_all_stripe_i12_dataset4(
134134
assert output.flags.c_contiguous
135135

136136

137+
@pytest.mark.parametrize(
138+
"dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected",
139+
[
140+
("synth_tomophantom1_dataset", 1.0, 61, 21, 53435.61),
141+
("synth_tomophantom1_dataset", 0.1, 61, 21, 67917.71),
142+
("synth_tomophantom1_dataset", 0.001, 61, 21, 70015.51),
143+
],
144+
ids=["snr_1", "snr_2", "snr_3"],
145+
)
146+
def test_remove_all_stripe_synth_tomophantom1_dataset(
147+
request, dataset_fixture, snr_val, la_size_val, sm_size_val, norm_res_expected
148+
):
149+
dataset = request.getfixturevalue(dataset_fixture)
150+
force_clean_gpu_memory()
151+
152+
output = remove_all_stripe(
153+
cp.copy(dataset[0]),
154+
snr=snr_val,
155+
la_size=la_size_val,
156+
sm_size=sm_size_val,
157+
dim=1,
158+
)
159+
np.savez(
160+
"/home/algol/Documents/DEV/httomolibgpu/zenodo-tests/large_data_archive/stripe_res2.npz",
161+
data=output.get(),
162+
)
163+
164+
residual_calc = dataset[0] - output
165+
norm_res = cp.linalg.norm(residual_calc.flatten())
166+
167+
assert isclose(norm_res, norm_res_expected, abs_tol=10**-2)
168+
assert not np.isnan(output).any(), "Output contains NaN values"
169+
assert output.dtype == np.float32
170+
assert output.flags.c_contiguous
171+
172+
137173
@pytest.mark.parametrize(
138174
"dataset_fixture, nvalue_val, vvalue_val, norm_res_expected",
139175
[

0 commit comments

Comments
 (0)