Commit 2df7621

works

1 parent 8cadbbe commit 2df7621

1 file changed: +34 -114 lines changed

services/dask-sidecar/tests/unit/test_utils_dask.py

Lines changed: 34 additions & 114 deletions
@@ -8,12 +8,12 @@
 import concurrent.futures
 import logging
 import time
-from collections.abc import AsyncIterator, Callable, Coroutine
 from typing import Any
 from unittest import mock

 import distributed
 import pytest
+from common_library.async_tools import maybe_await
 from dask_task_models_library.container_tasks.errors import TaskCancelledError
 from dask_task_models_library.container_tasks.events import TaskProgressEvent
 from dask_task_models_library.container_tasks.io import TaskCancelEventName
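Note: the common_library.async_tools.maybe_await helper imported above is the commit's central device; it lets one test body drive both the synchronous and the asynchronous dask client. Its implementation is not part of this diff; the sketch below only approximates the behavior the new call sites assume.

    # Sketch only (not from this commit): what the call sites below assume
    # common_library.async_tools.maybe_await does.
    import inspect
    from collections.abc import Awaitable
    from typing import TypeVar

    T = TypeVar("T")

    async def maybe_await(value: T | Awaitable[T]) -> T:
        # a synchronous dask client returns plain values, an asynchronous
        # one returns awaitables; normalize both to a plain value
        if inspect.isawaitable(value):
            return await value
        return value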
@@ -28,7 +28,6 @@
     publish_event,
 )
 from tenacity import Retrying
-from tenacity.asyncio import AsyncRetrying
 from tenacity.retry import retry_if_exception_type
 from tenacity.stop import stop_after_delay
 from tenacity.wait import wait_fixed
@@ -41,12 +40,24 @@
 ]


-@pytest.mark.parametrize("handler", [mock.Mock(), mock.AsyncMock()])
-async def test_publish_event(
+@pytest.fixture(params=["sync-dask-client", "async-dask-client"])
+def dask_client_multi(
+    request: pytest.FixtureRequest,
     dask_client: distributed.Client,
+    async_dask_client: distributed.Client,
+) -> distributed.Client:
+    if request.param == "sync-dask-client":
+        return dask_client
+    return async_dask_client
+
+
+@pytest.mark.parametrize(
+    "handler", [mock.Mock(), mock.AsyncMock()], ids=["sync-handler", "async-handler"]
+)
+async def test_publish_event(
+    dask_client_multi: distributed.Client,
     job_id: str,
     task_owner: TaskOwner,
-    monkeypatch: pytest.MonkeyPatch,
     handler: mock.Mock | mock.AsyncMock,
 ):
     event_to_publish = TaskProgressEvent(
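With the fixture above parametrized over both client flavors and the handler parametrized separately, test_publish_event fans out into four combinations (sync/async client x sync/async handler), which is what allows the dedicated async tests to be deleted further down.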
@@ -56,8 +67,8 @@ async def test_publish_event(
         task_owner=task_owner,
     )

-    # NOTE: only 1 handler per topic is allowed
-    dask_client.subscribe_topic(TaskProgressEvent.topic_name(), handler)
+    # NOTE: only 1 handler per topic is allowed at a time
+    dask_client_multi.subscribe_topic(TaskProgressEvent.topic_name(), handler)

     def _worker_task() -> int:
         with log_context(logging.INFO, "_worker_task"):
@@ -69,8 +80,8 @@ async def _() -> int:

         return asyncio.run(_())

-    future = dask_client.submit(_worker_task)
-    assert future.result(timeout=DASK_TESTING_TIMEOUT_S) == 2
+    future = dask_client_multi.submit(_worker_task)
+    assert await maybe_await(future.result(timeout=DASK_TESTING_TIMEOUT_S)) == 2

     for attempt in Retrying(
         wait=wait_fixed(0.2),
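Note: with a client created with asynchronous=True, Future.result() returns an awaitable instead of a plain value (the same holds for Client.get_events() below), so the call is wrapped in maybe_await to work against either client flavor.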
@@ -79,7 +90,9 @@ async def _() -> int:
         retry=retry_if_exception_type(AssertionError),
     ):
         with attempt:
-            events = dask_client.get_events(TaskProgressEvent.topic_name())
+            events = await maybe_await(
+                dask_client_multi.get_events(TaskProgressEvent.topic_name())
+            )
             assert events is not None, "No events received"
             assert isinstance(events, tuple)

@@ -92,114 +105,20 @@ async def _() -> int:
     assert received_task_log_event == event_to_publish


-async def test_publish_event_async(
-    async_dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
-):
-    event_to_publish = TaskProgressEvent(
-        job_id=job_id,
-        msg="the log",
-        progress=2,
-        task_owner=task_owner,
-    )
-
-    async def handler(event: tuple) -> None:
-        print("received event", event)
-        assert isinstance(event, tuple)
-        received_task_log_event = TaskProgressEvent.model_validate_json(event[1])
-        assert received_task_log_event == event_to_publish
-
-    async_dask_client.subscribe_topic(TaskProgressEvent.topic_name(), handler)
-
-    await publish_event(async_dask_client, event=event_to_publish)
-
-    async for attempt in AsyncRetrying(
-        wait=wait_fixed(0.2), stop=stop_after_delay(15), reraise=True
-    ):
-        with attempt:
-            events = await async_dask_client.get_events(TaskProgressEvent.topic_name())
-            assert events is not None
-
-    assert isinstance(events, tuple)
-    assert len(events) == 1
-    assert isinstance(events[0], tuple)
-    received_task_log_event = TaskProgressEvent.model_validate_json(events[0][1])
-    assert received_task_log_event == event_to_publish
-
-
-@pytest.fixture
-async def asyncio_task() -> AsyncIterator[Callable[[Coroutine], asyncio.Task]]:
-    created_tasks = []
-
-    def _creator(coro: Coroutine) -> asyncio.Task:
-        task = asyncio.create_task(coro, name="pytest_asyncio_task")
-        created_tasks.append(task)
-        return task
-
-    yield _creator
-    for task in created_tasks:
-        task.cancel()
-
-    await asyncio.gather(*created_tasks, return_exceptions=True)
-
-
-async def test_publish_event_async_using_task(
-    async_dask_client: distributed.Client,
-    asyncio_task: Callable[[Coroutine], asyncio.Task],
-    job_id: str,
-    task_owner: TaskOwner,
-):
-    NUMBER_OF_MESSAGES = 1000
-    received_messages = []
-
-    async def _consumer(event: tuple) -> None:
-        print("received event", event)
-        assert isinstance(event, tuple)
-        received_messages.append(event)
-
-    async_dask_client.subscribe_topic(TaskProgressEvent.topic_name(), _consumer)
-    await asyncio.sleep(0)
-
-    async def _dask_publisher_task(async_dask_client: distributed.Client) -> None:
-        print("--> starting publisher task")
-        for _ in range(NUMBER_OF_MESSAGES):
-            event_to_publish = TaskProgressEvent(
-                job_id=job_id,
-                progress=0.5,
-                task_owner=task_owner,
-            )
-            await publish_event(async_dask_client, event=event_to_publish)
-        print("<-- finished publisher task")
-
-    publisher_task = asyncio_task(_dask_publisher_task(async_dask_client))
-    assert publisher_task
-
-    async for attempt in AsyncRetrying(
-        retry=retry_if_exception_type(AssertionError),
-        stop=stop_after_delay(DASK_TESTING_TIMEOUT_S),
-        wait=wait_fixed(0.01),
-        reraise=True,
-    ):
-        with attempt:
-            print(
-                f"checking number of received messages...currently {len(received_messages)}"
-            )
-            assert len(received_messages) == NUMBER_OF_MESSAGES
-    print("all expected messages received")
-
-
-def _wait_for_task_to_start() -> None:
-    start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
+def _wait_for_task_to_start(dask_client: distributed.Client) -> None:
+    start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
     start_event.wait(timeout=DASK_TESTING_TIMEOUT_S)


-def _notify_task_is_started_and_ready() -> None:
-    start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
+def _notify_task_is_started_and_ready(dask_client: distributed.Client) -> None:
+    start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
     start_event.set()


 def _some_long_running_task() -> int:
     assert is_current_task_aborted() is False
-    _notify_task_is_started_and_ready()
+    dask_client = distributed.get_worker().client
+    _notify_task_is_started_and_ready(dask_client)

     for i in range(300):
         print("running iteration", i)
@@ -217,7 +136,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
     not work in distributed mode where an Event is necessary."""
     # NOTE: this works because the cluster is in the same machine
     future = dask_client.submit(_some_long_running_task)
-    _wait_for_task_to_start()
+    _wait_for_task_to_start(dask_client)
     future.cancel()
     assert future.cancelled()
     with pytest.raises(concurrent.futures.CancelledError):
@@ -227,7 +146,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
 def test_task_is_aborted_using_event(dask_client: distributed.Client):
     job_id = "myfake_job_id"
     future = dask_client.submit(_some_long_running_task, key=job_id)
-    _wait_for_task_to_start()
+    _wait_for_task_to_start(dask_client)

     dask_event = distributed.Event(TaskCancelEventName.format(job_id))
     dask_event.set()
@@ -244,7 +163,8 @@ def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int:

     async def _long_running_task_async() -> int:
         task_publishers = TaskPublisher(task_owner=task_owner)
-        _notify_task_is_started_and_ready()
+        worker = distributed.get_worker()
+        _notify_task_is_started_and_ready(worker.client)
         current_task = asyncio.current_task()
         assert current_task
         async with monitor_task_abortion(
@@ -270,7 +190,7 @@ def test_monitor_task_abortion(
     future = dask_client.submit(
         _some_long_running_task_with_monitoring, task_owner=task_owner, key=job_id
     )
-    _wait_for_task_to_start()
+    _wait_for_task_to_start(dask_client)
     # trigger cancellation
     dask_event = distributed.Event(TaskCancelEventName.format(job_id))
     dask_event.set()
