
Commit fad278d (parent f15f446)

tests are almost good to go

File tree: 1 file changed (+33, -14 lines)

services/dask-sidecar/tests/unit/test_tasks.py (33 additions, 14 deletions)
@@ -44,6 +44,7 @@
 from pytest_mock.plugin import MockerFixture
 from pytest_simcore.helpers.typing_env import EnvVarsDict
 from servicelib.rabbitmq._client import RabbitMQClient
+from servicelib.rabbitmq._constants import BIND_TO_ALL_TOPICS
 from settings_library.s3 import S3Settings
 from simcore_service_dask_sidecar.computational_sidecar.docker_utils import (
     LEGACY_SERVICE_LOG_FILE_NAME,
@@ -67,7 +68,9 @@


 @pytest.fixture()
-def dask_subsystem_mock(mocker: MockerFixture) -> dict[str, mock.Mock]:
+def dask_subsystem_mock(
+    mocker: MockerFixture, create_rabbitmq_client: Callable[[str], RabbitMQClient]
+) -> dict[str, mock.Mock]:
     # mock dask client
     dask_client_mock = mocker.patch("distributed.Client", autospec=True)

@@ -107,6 +110,13 @@ def dask_subsystem_mock(mocker: MockerFixture) -> dict[str, mock.Mock]:
         autospec=True,
         return_value=False,
     )
+    # mock dask rabbitmq plugin
+    mock_rabbitmq_client = create_rabbitmq_client("pytest_dask_sidecar_logs_publisher")
+    mocker.patch(
+        "simcore_service_dask_sidecar.dask_utils.get_rabbitmq_client",
+        autospec=True,
+        return_value=mock_rabbitmq_client,
+    )

     return {
         "dask_client": dask_client_mock,
@@ -151,10 +161,6 @@ def sidecar_params(self) -> dict[str, Any]:
     }


-pytest_simcore_core_services_selection = ["postgres"]
-pytest_simcore_ops_services_selection = []
-
-
 def _bash_check_env_exist(variable_name: str, variable_value: str) -> list[str]:
     return [
         f"if [ -z ${{{variable_name}+x}} ];then echo {variable_name} does not exist && exit 9;fi",
@@ -472,7 +478,10 @@ async def log_rabbit_client_parser(
     client = create_rabbitmq_client("dask_sidecar_pytest_logs_consumer")
     mock = mocker.AsyncMock(return_value=True)
     queue_name, _ = await client.subscribe(
-        LoggerRabbitMessage.get_channel_name(), mock, exclusive_queue=False
+        LoggerRabbitMessage.get_channel_name(),
+        mock,
+        exclusive_queue=False,
+        topics=[BIND_TO_ALL_TOPICS],
     )
     yield mock
     await client.unsubscribe(queue_name)
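
The extra `topics=[BIND_TO_ALL_TOPICS]` argument binds the consumer queue to every routing key of the log topic exchange, so the test receives log messages no matter which topic they are published under. A rough plain-aio-pika sketch of the same idea, assuming `BIND_TO_ALL_TOPICS` corresponds to the AMQP `#` wildcard and using illustrative exchange/URL names (the test itself goes through servicelib's `RabbitMQClient`):

```python
# Illustrative catch-all topic subscription with aio-pika; not the servicelib API.
import asyncio

import aio_pika


async def consume_everything(amqp_url: str = "amqp://guest:guest@localhost/") -> None:
    connection = await aio_pika.connect_robust(amqp_url)
    async with connection:
        channel = await connection.channel()
        exchange = await channel.declare_exchange(
            "simcore.services.logs",  # illustrative exchange name
            aio_pika.ExchangeType.TOPIC,
            durable=True,
        )
        queue = await channel.declare_queue(exclusive=True)
        # "#" is the AMQP topic wildcard: the queue receives every routing key
        await queue.bind(exchange, routing_key="#")
        async with queue.iterator() as messages:
            async for message in messages:
                async with message.process():
                    print(message.body.decode())


if __name__ == "__main__":
    asyncio.run(consume_everything())
```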
@@ -501,7 +510,7 @@ def test_run_computational_sidecar_real_fct(
     dask_subsystem_mock["dask_event_publish"].assert_any_call(
         name=event.topic_name()
     )
-    log_rabbit_client_parser.assert_called_once()
+    assert log_rabbit_client_parser.called

     # check that the task produces expected logs
     for log in sleeper_task.expected_logs:
@@ -584,12 +593,13 @@ def progress_sub(dask_client: distributed.Client) -> distributed.Sub:
     "integration_version, boot_mode", [("1.0.0", BootMode.CPU)], indirect=True
 )
 async def test_run_computational_sidecar_dask(
-    dask_client: distributed.Client,
+    app_environment: EnvVarsDict,
     sleeper_task: ServiceExampleParam,
     progress_sub: distributed.Sub,
     mocked_get_image_labels: mock.Mock,
     s3_settings: S3Settings,
     log_rabbit_client_parser: mock.AsyncMock,
+    dask_client: distributed.Client,
 ):
     future = dask_client.submit(
         run_computational_sidecar,
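
Moving `dask_client` to the end of the signature and requesting `app_environment` first presumably lets pytest prepare the environment and the RabbitMQ log consumer before the dask client and its workers come up, since same-scope fixtures with no dependency between them are generally set up in the order they are requested. A tiny self-contained illustration of that ordering behaviour:

```python
# Run with pytest: fixture parameters of the same scope are set up in request order.
import pytest

setup_order: list[str] = []


@pytest.fixture()
def env() -> str:
    setup_order.append("env")
    return "env"


@pytest.fixture()
def client() -> str:
    setup_order.append("client")
    return "client"


def test_env_is_prepared_before_client(env: str, client: str) -> None:
    assert setup_order == ["env", "client"]
```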
@@ -614,9 +624,14 @@ async def test_run_computational_sidecar_dask(
     ), "ordering of progress values incorrectly sorted!"
     assert worker_progresses[0] == 0, "missing/incorrect initial progress value"
     assert worker_progresses[-1] == 1, "missing/incorrect final progress value"
-    log_rabbit_client_parser.assert_called_once()
-    # worker_logs = [TaskLogEvent.model_validate_json(msg).log for msg in log_sub.buffer]
-    worker_logs = []
+    await asyncio.sleep(5)
+    assert log_rabbit_client_parser.called
+    worker_logs = [
+        message
+        for msg in log_rabbit_client_parser.call_args_list
+        for message in LoggerRabbitMessage.model_validate_json(msg.args[0]).messages
+    ]
+
     print(f"<-- we got {len(worker_logs)} lines of logs")

     for log in sleeper_task.expected_logs:
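
The added `await asyncio.sleep(5)` gives RabbitMQ time to deliver the log messages before the assertions run. A hypothetical alternative (not what this commit does) is to poll the consumer mock until it has been called or a deadline expires, which keeps the happy path fast while still tolerating slow delivery:

```python
# Hypothetical helper, not part of the commit: poll an AsyncMock-based consumer
# instead of sleeping a fixed 5 seconds.
import asyncio
from unittest import mock


async def wait_until_called(consumer: mock.AsyncMock, timeout: float = 5.0) -> None:
    deadline = asyncio.get_running_loop().time() + timeout
    while not consumer.called:
        if asyncio.get_running_loop().time() > deadline:
            raise AssertionError("consumer was never called before the deadline")
        await asyncio.sleep(0.1)
```

Such a helper would replace the fixed sleep with `await wait_until_called(log_rabbit_client_parser)`.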
@@ -688,9 +703,13 @@ async def test_run_computational_sidecar_dask_does_not_lose_messages_with_pubsub
     assert worker_progresses[0] == 0, "missing/incorrect initial progress value"
     assert worker_progresses[-1] == 1, "missing/incorrect final progress value"

-    log_rabbit_client_parser.assert_called_once()
-    # worker_logs = [TaskLogEvent.model_validate_json(msg).log for msg in log_sub.buffer]
-    worker_logs = []
+    await asyncio.sleep(5)
+    assert log_rabbit_client_parser.called
+    worker_logs = [
+        message
+        for msg in log_rabbit_client_parser.call_args_list
+        for message in LoggerRabbitMessage.model_validate_json(msg.args[0]).messages
+    ]
     # check all the awaited logs are in there
     filtered_worker_logs = filter(lambda log: "This is iteration" in log, worker_logs)
     assert len(list(filtered_worker_logs)) == NUMBER_OF_LOGS
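
Because the consumer callback may fire once per delivery, the exact call count is not deterministic, which is why `assert_called_once()` becomes a plain `assert log_rabbit_client_parser.called`; the log lines are then recovered by flattening `call_args_list`. A self-contained illustration of that flattening, using plain `json` in place of the `LoggerRabbitMessage` model:

```python
# Flattening AsyncMock.call_args_list into a list of log lines.
import asyncio
import json
from unittest import mock


async def _demo() -> None:
    consumer = mock.AsyncMock(return_value=True)
    # simulate two RabbitMQ deliveries, each carrying a batch of log lines
    await consumer(json.dumps({"messages": ["line 1", "line 2"]}).encode())
    await consumer(json.dumps({"messages": ["line 3"]}).encode())

    assert consumer.called  # call count depends on how deliveries were batched
    worker_logs = [
        line
        for call in consumer.call_args_list
        for line in json.loads(call.args[0])["messages"]
    ]
    assert worker_logs == ["line 1", "line 2", "line 3"]


if __name__ == "__main__":
    asyncio.run(_demo())
```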
