Skip to content

Commit 1f444ee

Browse files
authored
fix(cache_pull.py): avoid data race (#179)
The `session.get` method is not thread safe. Use thread local to provide each worker with its own session. It does not allow to maximize connections reuse as we originally intended, but, regardless, each thread will reuse persistent connections whenever possible. Also, we process input sequentially over the inputs and in parallel over threads, with batch submission and no pauses, so it's probably ~equivalent to before. Tests continue to WAI and do not require changes because the mocks mock `requests.Session`. Closes #178
1 parent 1f2ec1f commit 1f444ee

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

library/src/iqb/cli/cache_pull.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,15 @@
2828
from ..pipeline.cache import data_dir_or_default
2929
from .cache import cache
3030

31+
_thread_local = threading.local()
32+
33+
34+
def _get_session() -> requests.Session:
35+
"""Return a per-thread requests.Session for connection reuse."""
36+
if not hasattr(_thread_local, "session"):
37+
_thread_local.session = requests.Session()
38+
return _thread_local.session
39+
3140

3241
def _short_name(file: str) -> str:
3342
"""Extract a short display name from a cache path."""
@@ -43,7 +52,6 @@ def _now() -> str:
4352
def _download_one(
4453
entry: DiffEntry,
4554
data_dir: Path,
46-
session: requests.Session,
4755
progress: Progress,
4856
) -> dict[str, object]:
4957
"""Download a single file, verify SHA256, atomic-replace. Returns a metrics span."""
@@ -61,6 +69,7 @@ def _download_one(
6169
try:
6270
with TemporaryDirectory(dir=dest.parent) as tmp_dir:
6371
tmp_file = Path(tmp_dir) / dest.name
72+
session = _get_session()
6473
resp = session.get(entry.url, stream=True)
6574
resp.raise_for_status()
6675
cl = resp.headers.get("Content-Length")
@@ -120,7 +129,6 @@ def pull(data_dir: str | None, force: bool, jobs: int) -> None:
120129
click.echo("Nothing to download.")
121130
return
122131

123-
session = requests.Session()
124132
failed: list[tuple[str, str]] = []
125133
spans: list[dict[str, object]] = []
126134
t0 = time.monotonic()
@@ -134,8 +142,7 @@ def pull(data_dir: str | None, force: bool, jobs: int) -> None:
134142
ThreadPoolExecutor(max_workers=jobs) as pool,
135143
):
136144
futures = {
137-
pool.submit(_download_one, entry, resolved, session, progress): entry
138-
for entry in targets
145+
pool.submit(_download_one, entry, resolved, progress): entry for entry in targets
139146
}
140147
for future in as_completed(futures):
141148
entry = futures[future]

0 commit comments

Comments
 (0)