Skip to content

Commit 6532762

Browse files
authored
Merge pull request #811 from maresb/double-check-on-lookup-cache
Refactor lookup cache test to use multiprocessing instead of threading
2 parents 38e425f + 9904a52 commit 6532762

File tree

4 files changed

+131
-76
lines changed

4 files changed

+131
-76
lines changed

conda_lock/conda_lock.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,6 @@ def run_lock(
11931193
_conda_exe = determine_conda_executable(
11941194
conda_exe, mamba=mamba, micromamba=micromamba
11951195
)
1196-
logger.debug(f"Using conda executable: {_conda_exe}")
11971196
version_info = subprocess.check_output(
11981197
[_conda_exe, "--version"], encoding="utf-8"
11991198
).strip()

conda_lock/invoke_conda.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def determine_conda_executable(
6262
if is_micromamba(candidate):
6363
if determine_micromamba_version(str(candidate)) < Version("0.17"):
6464
mamba_root_prefix()
65+
logger.debug(f"Found conda executable: {candidate}")
6566
return candidate
6667
raise RuntimeError("Could not find conda (or compatible) executable")
6768

conda_lock/lookup_cache.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,17 @@ def cached_download_file(
5757
cache.mkdir(parents=True, exist_ok=True)
5858
clear_old_files_from_cache(cache, max_age_seconds=max_age_seconds)
5959

60-
destination_lock = (cache / cached_filename_for_url(url)).with_suffix(".lock")
60+
destination = cache / cached_filename_for_url(url)
61+
destination_lock = destination.with_suffix(".lock")
6162

6263
# Wait for any other process to finish downloading the file.
6364
# This way we can use the result from the current download without
6465
# spawning multiple concurrent downloads.
6566
while True:
6667
try:
68+
logger.debug(f"Attempting to acquire lock on {destination_lock}")
6769
with FileLock(str(destination_lock), timeout=5):
70+
logger.debug(f"Successfully acquired lock on {destination_lock}")
6871
return _download_to_or_read_from_cache(
6972
url,
7073
cache=cache,
@@ -77,6 +80,33 @@ def cached_download_file(
7780
)
7881

7982

83+
def _is_cached_file_fresh(
84+
destination: Path, dont_check_if_newer_than_seconds: float
85+
) -> bool:
86+
"""Check if a cached file exists and is fresh enough to use without checking.
87+
88+
(In this context, "checking" means that later, beyond the scope of this function,
89+
we will query the server with an ETag to see if the file has changed or if we
90+
get a 304 Not Modified response.)
91+
92+
A file is "fresh" if its age is positive and less than
93+
`dont_check_if_newer_than_seconds`.
94+
95+
Returns True if the file is fresh, False otherwise.
96+
"""
97+
if destination.is_file():
98+
age_seconds = get_age_seconds(destination)
99+
if age_seconds is None:
100+
raise RuntimeError(f"Error checking age of {destination}")
101+
if 0 <= age_seconds < dont_check_if_newer_than_seconds:
102+
logger.debug(
103+
f"Using cached file {destination} of age {age_seconds}s "
104+
f"without checking for updates"
105+
)
106+
return True
107+
return False
108+
109+
80110
def _download_to_or_read_from_cache(
81111
url: str, *, cache: Path, dont_check_if_newer_than_seconds: float
82112
) -> bytes:
@@ -93,22 +123,14 @@ def _download_to_or_read_from_cache(
93123
destination_etag = destination.with_suffix(".etag")
94124
request_headers = {"User-Agent": "conda-lock"}
95125
# Return the contents immediately if the file is fresh
96-
if destination.is_file():
97-
age_seconds = get_age_seconds(destination)
98-
if age_seconds is None:
99-
raise RuntimeError(f"Error checking age of {destination}")
100-
if 0 <= age_seconds < dont_check_if_newer_than_seconds:
101-
logger.debug(
102-
f"Using cached mapping {destination} of age {age_seconds}s "
103-
f"without checking for updates"
104-
)
105-
return destination.read_bytes()
106-
# Add the ETag from the last download, if it exists, to the headers.
107-
# The ETag is used to avoid downloading the file if it hasn't changed remotely.
108-
# Otherwise, download the file and cache the contents and ETag.
109-
if destination_etag.is_file():
110-
old_etag = destination_etag.read_text().strip()
111-
request_headers["If-None-Match"] = old_etag
126+
if _is_cached_file_fresh(destination, dont_check_if_newer_than_seconds):
127+
return destination.read_bytes()
128+
# Add the ETag from the last download, if it exists, to the headers.
129+
# The ETag is used to avoid downloading the file if it hasn't changed remotely.
130+
# Otherwise, download the file and cache the contents and ETag.
131+
if destination.is_file() and destination_etag.is_file():
132+
old_etag = destination_etag.read_text().strip()
133+
request_headers["If-None-Match"] = old_etag
112134
# Download the file and cache the result.
113135
logger.debug(f"Requesting {url}")
114136
res = requests.get(url, headers=request_headers)

tests/test_lookup_cache.py

Lines changed: 91 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import multiprocessing
12
import os
2-
import queue
3-
import threading
3+
import random
44
import time
55

66
from pathlib import Path
@@ -17,6 +17,44 @@
1717
)
1818

1919

20+
def _concurrent_download_worker(
21+
url,
22+
cache_root,
23+
result_queue,
24+
process_names_emitting_lock_warnings,
25+
process_names_calling_requests_get,
26+
request_count,
27+
):
28+
"""Download the file in a process and store the result in a queue."""
29+
30+
def mock_get(*args, **kwargs):
31+
time.sleep(6)
32+
response = MagicMock()
33+
response.content = b"content"
34+
response.status_code = 200
35+
process_name = multiprocessing.current_process().name
36+
process_names_calling_requests_get.append(process_name)
37+
request_count.value += 1
38+
return response
39+
40+
def mock_warning(msg, *args, **kwargs):
41+
if "Failed to acquire lock" in msg:
42+
process_names_emitting_lock_warnings.append(
43+
multiprocessing.current_process().name
44+
)
45+
46+
# Randomize which process calls cached_download_file first
47+
time.sleep(random.uniform(0, 0.1))
48+
49+
with patch("conda_lock.lookup_cache.requests.get", side_effect=mock_get), patch(
50+
"conda_lock.lookup_cache.logger.warning", side_effect=mock_warning
51+
):
52+
result = cached_download_file(
53+
url, cache_subdir_name="test_cache", cache_root=cache_root
54+
)
55+
result_queue.put(result)
56+
57+
2058
@pytest.fixture
2159
def mock_cache_dir(tmp_path):
2260
cache_dir = tmp_path / "cache" / "test_cache"
@@ -252,71 +290,66 @@ def wrapped_get(*args, **kwargs):
252290

253291

254292
def test_concurrent_cached_download_file(tmp_path):
255-
"""Test concurrent access to cached_download_file with 5 threads."""
293+
"""Test concurrent access to cached_download_file with 5 processes."""
256294
url = "https://example.com/test.json"
257-
results: queue.Queue[bytes] = queue.Queue()
258-
thread_names_emitting_lock_warnings: queue.Queue[str] = queue.Queue()
259-
thread_names_calling_requests_get: queue.Queue[str] = queue.Queue()
260295

261-
def mock_get(*args, **kwargs):
262-
time.sleep(6)
263-
response = MagicMock()
264-
response.content = b"content"
265-
response.status_code = 200
266-
thread_name = threading.current_thread().name
267-
thread_names_calling_requests_get.put(thread_name)
268-
return response
269-
270-
def download_file(result_queue):
271-
"""Download the file in a thread and store the result in a queue."""
272-
import random
273-
274-
# Randomize which thread calls cached_download_file first
275-
time.sleep(random.uniform(0, 0.1))
276-
result = cached_download_file(
277-
url, cache_subdir_name="test_cache", cache_root=tmp_path
278-
)
279-
result_queue.put(result)
280-
281-
with patch("requests.get", side_effect=mock_get) as mock_get, patch(
282-
"conda_lock.lookup_cache.logger"
283-
) as mock_logger:
284-
# Set up the logger to record which threads emit warnings
285-
def mock_warning(msg, *args, **kwargs):
286-
if "Failed to acquire lock" in msg:
287-
thread_names_emitting_lock_warnings.put(threading.current_thread().name)
288-
289-
mock_logger.warning.side_effect = mock_warning
290-
291-
# Create and start 5 threads
292-
thread_names = [f"CachedDownloadFileThread-{i}" for i in range(5)]
293-
threads = [
294-
threading.Thread(target=download_file, args=(results,), name=thread_name)
295-
for thread_name in thread_names
296+
# Use multiprocessing Manager to share state between processes
297+
with multiprocessing.Manager() as manager:
298+
results = manager.Queue()
299+
process_names_emitting_lock_warnings = manager.list()
300+
process_names_calling_requests_get = manager.list()
301+
request_count = manager.Value("i", 0)
302+
303+
# Create and start 5 processes
304+
process_names = [f"CachedDownloadFileProcess-{i}" for i in range(5)]
305+
processes = [
306+
multiprocessing.Process(
307+
target=_concurrent_download_worker,
308+
args=(
309+
url,
310+
tmp_path,
311+
results,
312+
process_names_emitting_lock_warnings,
313+
process_names_calling_requests_get,
314+
request_count,
315+
),
316+
name=process_name,
317+
)
318+
for process_name in process_names
296319
]
297-
for thread in threads:
298-
thread.start()
299-
for thread in threads:
300-
thread.join()
320+
for process in processes:
321+
process.start()
322+
for process in processes:
323+
process.join()
301324

302325
# Collect results from the queue
303-
assert results.qsize() == len(threads)
304-
assert all(result == b"content" for result in results.queue)
326+
assert results.qsize() == len(processes)
327+
results_list = []
328+
while not results.empty():
329+
results_list.append(results.get())
330+
assert all(result == b"content" for result in results_list)
305331

306-
# We expect one thread to have made the request and the other four
332+
# We expect one process to have made the request and the other four
307333
# to have emitted warnings.
334+
process_names_calling_requests_get_list = list(
335+
process_names_calling_requests_get
336+
)
337+
process_names_emitting_lock_warnings_list = list(
338+
process_names_emitting_lock_warnings
339+
)
340+
308341
assert (
309-
thread_names_calling_requests_get.qsize()
342+
len(process_names_calling_requests_get_list)
310343
== 1
311-
== len(set(thread_names_calling_requests_get.queue))
312-
== mock_get.call_count
313-
), f"{thread_names_calling_requests_get.queue=}"
344+
== len(set(process_names_calling_requests_get_list))
345+
== request_count.value
346+
), f"{process_names_calling_requests_get_list=}"
314347
assert (
315-
thread_names_emitting_lock_warnings.qsize()
348+
len(process_names_emitting_lock_warnings_list)
316349
== 4
317-
== len(set(thread_names_emitting_lock_warnings.queue))
318-
), f"{thread_names_emitting_lock_warnings.queue=}"
319-
assert set(thread_names) == set(
320-
thread_names_calling_requests_get.queue
321-
+ thread_names_emitting_lock_warnings.queue
350+
== len(set(process_names_emitting_lock_warnings_list))
351+
), f"{process_names_emitting_lock_warnings_list=}"
352+
assert set(process_names) == set(
353+
process_names_calling_requests_get_list
354+
+ process_names_emitting_lock_warnings_list
322355
)

0 commit comments

Comments
 (0)