|
8 | 8 | import json |
9 | 9 | import logging |
10 | 10 | import re |
| 11 | +import threading |
11 | 12 | from collections.abc import AsyncIterator, Callable, Coroutine, Iterable |
12 | 13 |
|
13 | 14 | # copied out from dask |
@@ -475,16 +476,55 @@ def mocked_get_image_labels( |
475 | 476 | async def log_rabbit_client_parser( |
476 | 477 | create_rabbitmq_client: Callable[[str], RabbitMQClient], mocker: MockerFixture |
477 | 478 | ) -> AsyncIterator[mock.AsyncMock]: |
478 | | - client = create_rabbitmq_client("dask_sidecar_pytest_logs_consumer") |
479 | | - mock = mocker.AsyncMock(return_value=True) |
480 | | - queue_name, _ = await client.subscribe( |
481 | | - LoggerRabbitMessage.get_channel_name(), |
482 | | - mock, |
483 | | - exclusive_queue=False, |
484 | | - topics=[BIND_TO_ALL_TOPICS], |
| 479 | + # Create a threading event to track when subscription is ready |
| 480 | + ready_event = threading.Event() |
| 481 | + shutdown_event = threading.Event() |
| 482 | + the_mock = mocker.AsyncMock(return_value=True) |
| 483 | + |
| 484 | + # Worker function to process messages in a separate thread |
| 485 | + def message_processor(a_mock: mock.AsyncMock): |
| 486 | + loop = asyncio.new_event_loop() |
| 487 | + asyncio.set_event_loop(loop) |
| 488 | + |
| 489 | + client = create_rabbitmq_client("dask_sidecar_pytest_logs_consumer") |
| 490 | + |
| 491 | + async def subscribe_and_process(a_mock: mock.AsyncMock): |
| 492 | + queue_name, _ = await client.subscribe( |
| 493 | + LoggerRabbitMessage.get_channel_name(), |
| 494 | + a_mock, |
| 495 | + exclusive_queue=False, |
| 496 | + topics=[BIND_TO_ALL_TOPICS], |
| 497 | + ) |
| 498 | + ready_event.set() |
| 499 | + |
| 500 | + # Wait until the test is done |
| 501 | + while not shutdown_event.is_set(): |
| 502 | + await asyncio.sleep(0.1) |
| 503 | + |
| 504 | + # Cleanup |
| 505 | + await client.unsubscribe(queue_name) |
| 506 | + |
| 507 | + loop.run_until_complete(subscribe_and_process(a_mock)) |
| 508 | + loop.run_until_complete(client.close()) |
| 509 | + loop.close() |
| 510 | + |
| 511 | + # Start the worker thread |
| 512 | + worker = threading.Thread( |
| 513 | + target=message_processor, kwargs={"a_mock": the_mock}, daemon=False |
485 | 514 | ) |
486 | | - yield mock |
487 | | - await client.unsubscribe(queue_name) |
| 515 | + worker.start() |
| 516 | + |
| 517 | + # Wait for subscription to be ready |
| 518 | + assert ready_event.wait(timeout=10), "Failed to initialize RabbitMQ subscription" |
| 519 | + |
| 520 | + try: |
| 521 | + yield the_mock |
| 522 | + finally: |
| 523 | + # Signal the worker thread to shut down |
| 524 | + shutdown_event.set() |
| 525 | + worker.join(timeout=5) |
| 526 | + if worker.is_alive(): |
| 527 | + _logger.warning("RabbitMQ worker thread did not terminate properly") |
488 | 528 |
|
489 | 529 |
|
490 | 530 | def test_run_computational_sidecar_real_fct( |
@@ -670,7 +710,7 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub |
670 | 710 | log_rabbit_client_parser: mock.AsyncMock, |
671 | 711 | ): |
672 | 712 | mocked_get_image_labels.assert_not_called() |
673 | | - NUMBER_OF_LOGS = 20000 |
| 713 | + NUMBER_OF_LOGS = 200 |
674 | 714 | future = dask_client.submit( |
675 | 715 | run_computational_sidecar, |
676 | 716 | **sidecar_task( |
|
0 commit comments