|
| 1 | +import multiprocessing |
1 | 2 | import os |
2 | | -import queue |
3 | | -import threading |
| 3 | +import random |
4 | 4 | import time |
5 | 5 |
|
6 | 6 | from pathlib import Path |
|
17 | 17 | ) |
18 | 18 |
|
19 | 19 |
|
| 20 | +def _concurrent_download_worker( |
| 21 | + url, |
| 22 | + cache_root, |
| 23 | + result_queue, |
| 24 | + process_names_emitting_lock_warnings, |
| 25 | + process_names_calling_requests_get, |
| 26 | + request_count, |
| 27 | +): |
| 28 | + """Download the file in a process and store the result in a queue.""" |
| 29 | + |
| 30 | + def mock_get(*args, **kwargs): |
| 31 | + time.sleep(6) |
| 32 | + response = MagicMock() |
| 33 | + response.content = b"content" |
| 34 | + response.status_code = 200 |
| 35 | + process_name = multiprocessing.current_process().name |
| 36 | + process_names_calling_requests_get.append(process_name) |
| 37 | + request_count.value += 1 |
| 38 | + return response |
| 39 | + |
| 40 | + def mock_warning(msg, *args, **kwargs): |
| 41 | + if "Failed to acquire lock" in msg: |
| 42 | + process_names_emitting_lock_warnings.append( |
| 43 | + multiprocessing.current_process().name |
| 44 | + ) |
| 45 | + |
| 46 | + # Randomize which process calls cached_download_file first |
| 47 | + time.sleep(random.uniform(0, 0.1)) |
| 48 | + |
| 49 | + with patch("conda_lock.lookup_cache.requests.get", side_effect=mock_get), patch( |
| 50 | + "conda_lock.lookup_cache.logger.warning", side_effect=mock_warning |
| 51 | + ): |
| 52 | + result = cached_download_file( |
| 53 | + url, cache_subdir_name="test_cache", cache_root=cache_root |
| 54 | + ) |
| 55 | + result_queue.put(result) |
| 56 | + |
| 57 | + |
20 | 58 | @pytest.fixture |
21 | 59 | def mock_cache_dir(tmp_path): |
22 | 60 | cache_dir = tmp_path / "cache" / "test_cache" |
@@ -252,71 +290,66 @@ def wrapped_get(*args, **kwargs): |
252 | 290 |
|
253 | 291 |
|
254 | 292 | def test_concurrent_cached_download_file(tmp_path): |
255 | | - """Test concurrent access to cached_download_file with 5 threads.""" |
| 293 | + """Test concurrent access to cached_download_file with 5 processes.""" |
256 | 294 | url = "https://example.com/test.json" |
257 | | - results: queue.Queue[bytes] = queue.Queue() |
258 | | - thread_names_emitting_lock_warnings: queue.Queue[str] = queue.Queue() |
259 | | - thread_names_calling_requests_get: queue.Queue[str] = queue.Queue() |
260 | 295 |
|
261 | | - def mock_get(*args, **kwargs): |
262 | | - time.sleep(6) |
263 | | - response = MagicMock() |
264 | | - response.content = b"content" |
265 | | - response.status_code = 200 |
266 | | - thread_name = threading.current_thread().name |
267 | | - thread_names_calling_requests_get.put(thread_name) |
268 | | - return response |
269 | | - |
270 | | - def download_file(result_queue): |
271 | | - """Download the file in a thread and store the result in a queue.""" |
272 | | - import random |
273 | | - |
274 | | - # Randomize which thread calls cached_download_file first |
275 | | - time.sleep(random.uniform(0, 0.1)) |
276 | | - result = cached_download_file( |
277 | | - url, cache_subdir_name="test_cache", cache_root=tmp_path |
278 | | - ) |
279 | | - result_queue.put(result) |
280 | | - |
281 | | - with patch("requests.get", side_effect=mock_get) as mock_get, patch( |
282 | | - "conda_lock.lookup_cache.logger" |
283 | | - ) as mock_logger: |
284 | | - # Set up the logger to record which threads emit warnings |
285 | | - def mock_warning(msg, *args, **kwargs): |
286 | | - if "Failed to acquire lock" in msg: |
287 | | - thread_names_emitting_lock_warnings.put(threading.current_thread().name) |
288 | | - |
289 | | - mock_logger.warning.side_effect = mock_warning |
290 | | - |
291 | | - # Create and start 5 threads |
292 | | - thread_names = [f"CachedDownloadFileThread-{i}" for i in range(5)] |
293 | | - threads = [ |
294 | | - threading.Thread(target=download_file, args=(results,), name=thread_name) |
295 | | - for thread_name in thread_names |
| 296 | + # Use multiprocessing Manager to share state between processes |
| 297 | + with multiprocessing.Manager() as manager: |
| 298 | + results = manager.Queue() |
| 299 | + process_names_emitting_lock_warnings = manager.list() |
| 300 | + process_names_calling_requests_get = manager.list() |
| 301 | + request_count = manager.Value("i", 0) |
| 302 | + |
| 303 | + # Create and start 5 processes |
| 304 | + process_names = [f"CachedDownloadFileProcess-{i}" for i in range(5)] |
| 305 | + processes = [ |
| 306 | + multiprocessing.Process( |
| 307 | + target=_concurrent_download_worker, |
| 308 | + args=( |
| 309 | + url, |
| 310 | + tmp_path, |
| 311 | + results, |
| 312 | + process_names_emitting_lock_warnings, |
| 313 | + process_names_calling_requests_get, |
| 314 | + request_count, |
| 315 | + ), |
| 316 | + name=process_name, |
| 317 | + ) |
| 318 | + for process_name in process_names |
296 | 319 | ] |
297 | | - for thread in threads: |
298 | | - thread.start() |
299 | | - for thread in threads: |
300 | | - thread.join() |
| 320 | + for process in processes: |
| 321 | + process.start() |
| 322 | + for process in processes: |
| 323 | + process.join() |
301 | 324 |
|
302 | 325 | # Collect results from the queue |
303 | | - assert results.qsize() == len(threads) |
304 | | - assert all(result == b"content" for result in results.queue) |
| 326 | + assert results.qsize() == len(processes) |
| 327 | + results_list = [] |
| 328 | + while not results.empty(): |
| 329 | + results_list.append(results.get()) |
| 330 | + assert all(result == b"content" for result in results_list) |
305 | 331 |
|
306 | | - # We expect one thread to have made the request and the other four |
| 332 | + # We expect one process to have made the request and the other four |
307 | 333 | # to have emitted warnings. |
| 334 | + process_names_calling_requests_get_list = list( |
| 335 | + process_names_calling_requests_get |
| 336 | + ) |
| 337 | + process_names_emitting_lock_warnings_list = list( |
| 338 | + process_names_emitting_lock_warnings |
| 339 | + ) |
| 340 | + |
308 | 341 | assert ( |
309 | | - thread_names_calling_requests_get.qsize() |
| 342 | + len(process_names_calling_requests_get_list) |
310 | 343 | == 1 |
311 | | - == len(set(thread_names_calling_requests_get.queue)) |
312 | | - == mock_get.call_count |
313 | | - ), f"{thread_names_calling_requests_get.queue=}" |
| 344 | + == len(set(process_names_calling_requests_get_list)) |
| 345 | + == request_count.value |
| 346 | + ), f"{process_names_calling_requests_get_list=}" |
314 | 347 | assert ( |
315 | | - thread_names_emitting_lock_warnings.qsize() |
| 348 | + len(process_names_emitting_lock_warnings_list) |
316 | 349 | == 4 |
317 | | - == len(set(thread_names_emitting_lock_warnings.queue)) |
318 | | - ), f"{thread_names_emitting_lock_warnings.queue=}" |
319 | | - assert set(thread_names) == set( |
320 | | - thread_names_calling_requests_get.queue |
321 | | - + thread_names_emitting_lock_warnings.queue |
| 350 | + == len(set(process_names_emitting_lock_warnings_list)) |
| 351 | + ), f"{process_names_emitting_lock_warnings_list=}" |
| 352 | + assert set(process_names) == set( |
| 353 | + process_names_calling_requests_get_list |
| 354 | + + process_names_emitting_lock_warnings_list |
322 | 355 | ) |
0 commit comments