Skip to content

Commit 8004632

Browse files
jewettaijfcyaugenst-flex
authored and committed
files are downloaded atomically using path.rename()
1 parent 5f68fdf commit 8004632

File tree

3 files changed

+157
-21
lines changed

3 files changed

+157
-21
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626
or `Absorber` classes (or when invoking `pml()`, `stable_pml()`, or `absorber()` functions)
2727
with fewer layers than recommended.
2828
- Warnings and error messages originating from `Structure`, `Source`, or `Monitor` classes now refer to problematic objects by their user-supplied `name` attribute, alongside their index.
29+
- File downloads are atomic. Interruptions or failures during download will no longer result in incomplete files.
2930

3031
### Fixed
3132
- Arrow lengths are now scaled consistently in the X and Y directions, and their lengths no longer exceed the height of the plot window.

tests/test_web/test_s3utils.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
from __future__ import annotations
2+
3+
from unittest.mock import MagicMock
4+
5+
import pytest
6+
7+
import tidy3d
8+
from tidy3d.web.core import s3utils
9+
10+
11+
@pytest.fixture
def mock_S3STSToken(monkeypatch):
    """
    Fixture that swaps ``s3utils._S3STSToken`` for a factory returning a single mock token.

    The mock token carries empty ``cloud_path`` / ``user_credential`` values, empty bucket
    and key, and reports itself as not expired. The token itself is returned so that tests
    can override ``get_bucket`` and ``get_s3_key`` as needed.
    """
    token = MagicMock()
    token.cloud_path = ""
    token.user_credential = ""
    token.get_bucket = lambda: ""
    token.get_s3_key = lambda: ""
    token.is_expired = lambda: False
    # ``boto3.client`` is looked up at call time, not when the fixture is created.
    token.get_client = lambda: s3utils.boto3.client()

    token_factory = MagicMock(return_value=token)
    monkeypatch.setattr(s3utils, "_S3STSToken", token_factory)
    return token
24+
25+
26+
@pytest.fixture
def mock_get_s3_sts_token(monkeypatch):
    """
    Fixture that replaces ``s3utils.get_s3_sts_token`` with a stand-in that simply
    constructs ``s3utils._S3STSToken(resource_id, remote_filename)``.
    """

    def _fake_token_getter(resource_id, remote_filename):
        # Resolve ``_S3STSToken`` at call time so any patch applied to it is picked up.
        return s3utils._S3STSToken(resource_id, remote_filename)

    monkeypatch.setattr(s3utils, "get_s3_sts_token", _fake_token_getter)
    return _fake_token_getter
35+
36+
37+
@pytest.fixture
def mock_s3_client(monkeypatch):
    """
    Fixture that provides a generic mock S3 client.

    ``boto3.client`` (as reachable through ``tidy3d.web.core.s3utils.boto3``) is replaced by
    a factory that returns this mock for any arguments. No method-specific ``side_effect``
    is configured here; each unit test installs its own.
    """
    client = MagicMock()
    client_factory = MagicMock(return_value=client)
    monkeypatch.setattr(s3utils.boto3, "client", client_factory)
    return client
52+
53+
54+
def test_download_s3_file_success(mock_s3_client, mock_get_s3_sts_token, mock_S3STSToken, tmp_path):
    """A completed download leaves only the final file, holding the expected content."""
    target_file = tmp_path / "downloaded_file.txt"
    payload = "abcdefg"

    def fake_download(Bucket, Key, Filename, Callback, Config, **kwargs):
        # Stand in for the S3 client: write the payload to the path we were handed.
        with open(Filename, "w") as handle:
            handle.write(payload)

    mock_s3_client.download_file.side_effect = fake_download
    mock_S3STSToken.get_bucket = lambda: "test-bucket"
    mock_S3STSToken.get_s3_key = lambda: "test-key"

    s3utils.download_file(
        resource_id="1234567890",
        remote_filename=target_file.name,
        to_file=str(target_file),
        verbose=False,
        progress_callback=None,
    )

    # The client must have been asked exactly once, and for a temporary (in-transit) path.
    mock_s3_client.download_file.assert_called_once()
    used_kwargs = mock_s3_client.download_file.call_args.kwargs
    assert used_kwargs["Bucket"] == "test-bucket"
    assert used_kwargs["Key"] == "test-key"
    assert used_kwargs["Filename"].endswith(s3utils.IN_TRANSIT_SUFFIX)

    # The final file exists with the full content ...
    assert target_file.exists()
    assert target_file.read_text() == payload

    # ... and no temporary files are left next to it.
    leftovers = [
        entry
        for entry in target_file.parent.iterdir()
        if entry.name.endswith(s3utils.IN_TRANSIT_SUFFIX)
    ]
    assert not leftovers
87+
88+
89+
def test_download_s3_file_raises_oserror(
    mock_s3_client, mock_get_s3_sts_token, mock_S3STSToken, tmp_path
):
    """A download failing with ``OSError`` (No space left on device) leaves no files behind."""
    target_file = tmp_path / "downloaded_file.txt"

    def fake_failing_download(Bucket, Key, Filename, Callback, Config, **kwargs):
        # Write a partial file first, then fail the way a full disk would.
        with open(Filename, "w") as handle:
            handle.write("abc")
        raise OSError("No space left on device")

    mock_s3_client.download_file.side_effect = fake_failing_download
    mock_S3STSToken.get_bucket = lambda: "test-bucket"
    mock_S3STSToken.get_s3_key = lambda: "test-key"

    # The error raised inside the client must propagate to the caller.
    with pytest.raises(OSError, match="No space left on device"):
        s3utils.download_file(
            resource_id="1234567890",
            remote_filename=target_file.name,
            to_file=str(target_file),
            verbose=False,
            progress_callback=None,
        )

    # The client must have been asked exactly once, and for a temporary (in-transit) path.
    mock_s3_client.download_file.assert_called_once()
    used_kwargs = mock_s3_client.download_file.call_args.kwargs
    assert used_kwargs["Bucket"] == "test-bucket"
    assert used_kwargs["Key"] == "test-key"
    assert used_kwargs["Filename"].endswith(s3utils.IN_TRANSIT_SUFFIX)

    # Since downloading failed, neither the final file nor any temporary file may exist.
    assert not target_file.exists()
    leftovers = [
        entry
        for entry in target_file.parent.iterdir()
        if entry.name.endswith(s3utils.IN_TRANSIT_SUFFIX)
    ]
    assert not leftovers

tidy3d/web/core/s3utils.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@
2929
from .file_util import extract_gzip_file
3030
from .http_util import http
3131

32+
# Suffix of the temporary file a download is written to before it is renamed into place.
IN_TRANSIT_SUFFIX = ".tmp"
33+
3234

3335
class _UserCredential(BaseModel):
3436
"""Stores information about user credentials."""
@@ -312,12 +314,12 @@ def download_file(
312314
# set to_file if None
313315
if not to_file:
314316
path = pathlib.Path(resource_id)
315-
to_file = path / remote_basename
317+
to_path = path / remote_basename
316318
else:
317-
to_file = pathlib.Path(to_file)
319+
to_path = pathlib.Path(to_file)
318320

319-
# make the leading directories in the 'to_file', if any
320-
to_file.parent.mkdir(parents=True, exist_ok=True)
321+
# make the leading directories in the 'to_path', if any
322+
to_path.parent.mkdir(parents=True, exist_ok=True)
321323

322324
def _download(_callback: Callable) -> None:
323325
"""Perform the download with a callback function.
@@ -327,14 +329,25 @@ def _download(_callback: Callable) -> None:
327329
_callback : Callable[[float], None]
328330
Callback function for download, accepts ``bytes_in_chunk``
329331
"""
330-
331-
client.download_file(
332-
Bucket=token.get_bucket(),
333-
Filename=str(to_file),
334-
Key=token.get_s3_key(),
335-
Callback=_callback,
336-
Config=_s3_config,
337-
)
332+
# Caller can assume the existence of the file means download succeeded.
333+
# So make sure this file does not exist until that assumption is true.
334+
to_path.unlink(missing_ok=True)
335+
# Download to a temporary file.
336+
try:
337+
fd, tmp_file_path_str = tempfile.mkstemp(suffix=IN_TRANSIT_SUFFIX, dir=to_path.parent)
338+
os.close(fd) # `tempfile.mkstemp()` creates and opens a randomly named file. close it.
339+
to_path_tmp = pathlib.Path(tmp_file_path_str)
340+
client.download_file(
341+
Bucket=token.get_bucket(),
342+
Filename=tmp_file_path_str,
343+
Key=token.get_s3_key(),
344+
Callback=_callback,
345+
Config=_s3_config,
346+
)
347+
to_path_tmp.rename(to_file)
348+
except Exception as e:
349+
to_path_tmp.unlink(missing_ok=True) # Delete incompletely downloaded file.
350+
raise e
338351

339352
if progress_callback is not None:
340353
_download(progress_callback)
@@ -355,7 +368,7 @@ def _callback(bytes_in_chunk):
355368
else:
356369
_download(lambda bytes_in_chunk: None)
357370

358-
return to_file
371+
return to_path
359372

360373

361374
def download_gz_file(
@@ -394,24 +407,24 @@ def download_gz_file(
394407

395408
# Otherwise, download and unzip
396409
# The tempfile is set as ``hdf5.gz`` so that the mock download in the webapi tests works
397-
tmp_file, tmp_file_path = tempfile.mkstemp(".hdf5.gz")
410+
tmp_file, tmp_file_path_str = tempfile.mkstemp(".hdf5.gz")
398411
os.close(tmp_file)
399412

400413
# make the leading directories in the 'to_file', if any
401-
to_file = pathlib.Path(to_file)
402-
to_file.parent.mkdir(parents=True, exist_ok=True)
414+
to_path = pathlib.Path(to_file)
415+
to_path.parent.mkdir(parents=True, exist_ok=True)
403416
try:
404417
download_file(
405418
resource_id,
406419
remote_filename,
407-
to_file=tmp_file_path,
420+
to_file=tmp_file_path_str,
408421
verbose=verbose,
409422
progress_callback=progress_callback,
410423
)
411-
if os.path.exists(tmp_file_path):
412-
extract_gzip_file(tmp_file_path, to_file)
424+
if os.path.exists(tmp_file_path_str):
425+
extract_gzip_file(tmp_file_path_str, to_path)
413426
else:
414427
raise WebError(f"Failed to download and extract '{remote_filename}'.")
415428
finally:
416-
os.unlink(tmp_file_path)
417-
return to_file
429+
os.unlink(tmp_file_path_str)
430+
return to_path

0 commit comments

Comments
 (0)