Skip to content

Commit 973647c

Browse files
Merge pull request #381 from MannLabs/fix_364
improve robustness of remote dataset retrieval
2 parents 7e816cf + ada2eef commit 973647c

File tree

2 files changed

+118
-4
lines changed

2 files changed

+118
-4
lines changed

src/scportrait/data/_datasets.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,20 @@ def _get_remote_dataset(
3030
data_dir.mkdir(parents=True, exist_ok=True)
3131
save_path = data_dir / dataset
3232

33-
if force_download:
34-
_download(url=url, output_path=str(save_path), output_file_name=outfile_name, archive_format=archive_format)
35-
elif not save_path.exists():
36-
_download(url=url, output_path=str(save_path), output_file_name=outfile_name, archive_format=archive_format)
33+
dataset_exists = save_path.exists()
34+
expected_path = save_path / name if name is not None else None
35+
missing_expected_file = dataset_exists and expected_path is not None and not expected_path.exists()
36+
should_download = force_download or not dataset_exists or missing_expected_file
37+
38+
if should_download:
39+
_download(
40+
url=url,
41+
output_path=str(save_path),
42+
output_file_name=outfile_name,
43+
archive_format=archive_format,
44+
overwrite=force_download or missing_expected_file,
45+
)
46+
3747
if name is None:
3848
return save_path
3949
else:
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
from __future__ import annotations
2+
3+
import scportrait.data._datasets as datasets
4+
5+
6+
def test_get_remote_dataset_redownloads_when_expected_file_missing(tmp_path, monkeypatch):
7+
data_root = tmp_path / "data"
8+
save_path = data_root / "example_dataset"
9+
save_path.mkdir(parents=True)
10+
11+
captured = {}
12+
13+
def _fake_download(**kwargs):
14+
captured.update(kwargs)
15+
16+
monkeypatch.setattr(datasets, "get_data_dir", lambda: data_root)
17+
monkeypatch.setattr(datasets, "_download", _fake_download)
18+
19+
returned = datasets._get_remote_dataset(
20+
dataset="example_dataset",
21+
url="https://example.com/file.dat",
22+
name="file.dat",
23+
archive_format=None,
24+
outfile_name="file.dat",
25+
)
26+
27+
assert returned == save_path / "file.dat"
28+
assert captured["output_path"] == str(save_path)
29+
assert captured["output_file_name"] == "file.dat"
30+
assert captured["overwrite"] is True
31+
32+
33+
def test_get_remote_dataset_skips_download_when_expected_file_exists(tmp_path, monkeypatch):
34+
data_root = tmp_path / "data"
35+
save_path = data_root / "example_dataset"
36+
save_path.mkdir(parents=True)
37+
(save_path / "file.dat").write_text("ok")
38+
39+
def _unexpected_download(**kwargs):
40+
raise AssertionError("_download should not be called when expected file already exists")
41+
42+
monkeypatch.setattr(datasets, "get_data_dir", lambda: data_root)
43+
monkeypatch.setattr(datasets, "_download", _unexpected_download)
44+
45+
returned = datasets._get_remote_dataset(
46+
dataset="example_dataset",
47+
url="https://example.com/file.dat",
48+
name="file.dat",
49+
archive_format=None,
50+
outfile_name="file.dat",
51+
)
52+
53+
assert returned == save_path / "file.dat"
54+
55+
56+
def test_get_remote_dataset_downloads_when_dataset_dir_missing(tmp_path, monkeypatch):
57+
data_root = tmp_path / "data"
58+
save_path = data_root / "example_dataset"
59+
60+
captured = {}
61+
62+
def _fake_download(**kwargs):
63+
captured.update(kwargs)
64+
65+
monkeypatch.setattr(datasets, "get_data_dir", lambda: data_root)
66+
monkeypatch.setattr(datasets, "_download", _fake_download)
67+
68+
returned = datasets._get_remote_dataset(
69+
dataset="example_dataset",
70+
url="https://example.com/archive.zip",
71+
name=None,
72+
archive_format="zip",
73+
outfile_name=None,
74+
)
75+
76+
assert returned == save_path
77+
assert captured["output_path"] == str(save_path)
78+
assert captured["overwrite"] is False
79+
80+
81+
def test_get_remote_dataset_named_file_missing_dir_does_not_force_overwrite(tmp_path, monkeypatch):
82+
data_root = tmp_path / "data"
83+
save_path = data_root / "example_dataset"
84+
85+
captured = {}
86+
87+
def _fake_download(**kwargs):
88+
captured.update(kwargs)
89+
90+
monkeypatch.setattr(datasets, "get_data_dir", lambda: data_root)
91+
monkeypatch.setattr(datasets, "_download", _fake_download)
92+
93+
returned = datasets._get_remote_dataset(
94+
dataset="example_dataset",
95+
url="https://example.com/file.dat",
96+
name="file.dat",
97+
archive_format=None,
98+
outfile_name="file.dat",
99+
)
100+
101+
assert returned == save_path / "file.dat"
102+
assert captured["output_path"] == str(save_path)
103+
assert captured["output_file_name"] == "file.dat"
104+
assert captured["overwrite"] is False

0 commit comments

Comments
 (0)