88import concurrent .futures
99import logging
1010import time
11- from collections .abc import AsyncIterator , Callable , Coroutine
1211from typing import Any
1312from unittest import mock
1413
1514import distributed
1615import pytest
16+ from common_library .async_tools import maybe_await
1717from dask_task_models_library .container_tasks .errors import TaskCancelledError
1818from dask_task_models_library .container_tasks .events import TaskProgressEvent
1919from dask_task_models_library .container_tasks .io import TaskCancelEventName
2828 publish_event ,
2929)
3030from tenacity import Retrying
31- from tenacity .asyncio import AsyncRetrying
3231from tenacity .retry import retry_if_exception_type
3332from tenacity .stop import stop_after_delay
3433from tenacity .wait import wait_fixed
4140]
4241
4342
44- @pytest .mark .parametrize ("handler" , [mock .Mock (), mock .AsyncMock ()])
45- async def test_publish_event (
43+ @pytest .fixture (params = ["sync-dask-client" , "async-dask-client" ])
44+ def dask_client_multi (
45+ request : pytest .FixtureRequest ,
4646 dask_client : distributed .Client ,
47+ async_dask_client : distributed .Client ,
48+ ) -> distributed .Client :
49+ if request .param == "sync-dask-client" :
50+ return dask_client
51+ return async_dask_client
52+
53+
54+ @pytest .mark .parametrize (
55+ "handler" , [mock .Mock (), mock .AsyncMock ()], ids = ["sync-handler" , "async-handler" ]
56+ )
57+ async def test_publish_event (
58+ dask_client_multi : distributed .Client ,
4759 job_id : str ,
4860 task_owner : TaskOwner ,
49- monkeypatch : pytest .MonkeyPatch ,
5061 handler : mock .Mock | mock .AsyncMock ,
5162):
5263 event_to_publish = TaskProgressEvent (
@@ -56,8 +67,8 @@ async def test_publish_event(
5667 task_owner = task_owner ,
5768 )
5869
59- # NOTE: only 1 handler per topic is allowed
60- dask_client .subscribe_topic (TaskProgressEvent .topic_name (), handler )
70+ # NOTE: only 1 handler per topic is allowed at a time
71+ dask_client_multi .subscribe_topic (TaskProgressEvent .topic_name (), handler )
6172
6273 def _worker_task () -> int :
6374 with log_context (logging .INFO , "_worker_task" ):
@@ -69,8 +80,8 @@ async def _() -> int:
6980
7081 return asyncio .run (_ ())
7182
72- future = dask_client .submit (_worker_task )
73- assert future .result (timeout = DASK_TESTING_TIMEOUT_S ) == 2
83+ future = dask_client_multi .submit (_worker_task )
84+ assert await maybe_await ( future .result (timeout = DASK_TESTING_TIMEOUT_S ) ) == 2
7485
7586 for attempt in Retrying (
7687 wait = wait_fixed (0.2 ),
@@ -79,7 +90,9 @@ async def _() -> int:
7990 retry = retry_if_exception_type (AssertionError ),
8091 ):
8192 with attempt :
82- events = dask_client .get_events (TaskProgressEvent .topic_name ())
93+ events = await maybe_await (
94+ dask_client_multi .get_events (TaskProgressEvent .topic_name ())
95+ )
8396 assert events is not None , "No events received"
8497 assert isinstance (events , tuple )
8598
@@ -92,114 +105,20 @@ async def _() -> int:
92105 assert received_task_log_event == event_to_publish
93106
94107
95- async def test_publish_event_async (
96- async_dask_client : distributed .Client , job_id : str , task_owner : TaskOwner
97- ):
98- event_to_publish = TaskProgressEvent (
99- job_id = job_id ,
100- msg = "the log" ,
101- progress = 2 ,
102- task_owner = task_owner ,
103- )
104-
105- async def handler (event : tuple ) -> None :
106- print ("received event" , event )
107- assert isinstance (event , tuple )
108- received_task_log_event = TaskProgressEvent .model_validate_json (event [1 ])
109- assert received_task_log_event == event_to_publish
110-
111- async_dask_client .subscribe_topic (TaskProgressEvent .topic_name (), handler )
112-
113- await publish_event (async_dask_client , event = event_to_publish )
114-
115- async for attempt in AsyncRetrying (
116- wait = wait_fixed (0.2 ), stop = stop_after_delay (15 ), reraise = True
117- ):
118- with attempt :
119- events = await async_dask_client .get_events (TaskProgressEvent .topic_name ())
120- assert events is not None
121-
122- assert isinstance (events , tuple )
123- assert len (events ) == 1
124- assert isinstance (events [0 ], tuple )
125- received_task_log_event = TaskProgressEvent .model_validate_json (events [0 ][1 ])
126- assert received_task_log_event == event_to_publish
127-
128-
129- @pytest .fixture
130- async def asyncio_task () -> AsyncIterator [Callable [[Coroutine ], asyncio .Task ]]:
131- created_tasks = []
132-
133- def _creator (coro : Coroutine ) -> asyncio .Task :
134- task = asyncio .create_task (coro , name = "pytest_asyncio_task" )
135- created_tasks .append (task )
136- return task
137-
138- yield _creator
139- for task in created_tasks :
140- task .cancel ()
141-
142- await asyncio .gather (* created_tasks , return_exceptions = True )
143-
144-
145- async def test_publish_event_async_using_task (
146- async_dask_client : distributed .Client ,
147- asyncio_task : Callable [[Coroutine ], asyncio .Task ],
148- job_id : str ,
149- task_owner : TaskOwner ,
150- ):
151- NUMBER_OF_MESSAGES = 1000
152- received_messages = []
153-
154- async def _consumer (event : tuple ) -> None :
155- print ("received event" , event )
156- assert isinstance (event , tuple )
157- received_messages .append (event )
158-
159- async_dask_client .subscribe_topic (TaskProgressEvent .topic_name (), _consumer )
160- await asyncio .sleep (0 )
161-
162- async def _dask_publisher_task (async_dask_client : distributed .Client ) -> None :
163- print ("--> starting publisher task" )
164- for _ in range (NUMBER_OF_MESSAGES ):
165- event_to_publish = TaskProgressEvent (
166- job_id = job_id ,
167- progress = 0.5 ,
168- task_owner = task_owner ,
169- )
170- await publish_event (async_dask_client , event = event_to_publish )
171- print ("<-- finished publisher task" )
172-
173- publisher_task = asyncio_task (_dask_publisher_task (async_dask_client ))
174- assert publisher_task
175-
176- async for attempt in AsyncRetrying (
177- retry = retry_if_exception_type (AssertionError ),
178- stop = stop_after_delay (DASK_TESTING_TIMEOUT_S ),
179- wait = wait_fixed (0.01 ),
180- reraise = True ,
181- ):
182- with attempt :
183- print (
184- f"checking number of received messages...currently { len (received_messages )} "
185- )
186- assert len (received_messages ) == NUMBER_OF_MESSAGES
187- print ("all expected messages received" )
188-
189-
190- def _wait_for_task_to_start () -> None :
191- start_event = distributed .Event (DASK_TASK_STARTED_EVENT )
108+ def _wait_for_task_to_start (dask_client : distributed .Client ) -> None :
109+ start_event = distributed .Event (DASK_TASK_STARTED_EVENT , dask_client )
192110 start_event .wait (timeout = DASK_TESTING_TIMEOUT_S )
193111
194112
195- def _notify_task_is_started_and_ready () -> None :
196- start_event = distributed .Event (DASK_TASK_STARTED_EVENT )
113+ def _notify_task_is_started_and_ready (dask_client : distributed . Client ) -> None :
114+ start_event = distributed .Event (DASK_TASK_STARTED_EVENT , dask_client )
197115 start_event .set ()
198116
199117
200118def _some_long_running_task () -> int :
201119 assert is_current_task_aborted () is False
202- _notify_task_is_started_and_ready ()
120+ dask_client = distributed .get_worker ().client
121+ _notify_task_is_started_and_ready (dask_client )
203122
204123 for i in range (300 ):
205124 print ("running iteration" , i )
@@ -217,7 +136,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
217136 not work in distributed mode where an Event is necessary."""
218137 # NOTE: this works because the cluster is in the same machine
219138 future = dask_client .submit (_some_long_running_task )
220- _wait_for_task_to_start ()
139+ _wait_for_task_to_start (dask_client )
221140 future .cancel ()
222141 assert future .cancelled ()
223142 with pytest .raises (concurrent .futures .CancelledError ):
@@ -227,7 +146,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
227146def test_task_is_aborted_using_event (dask_client : distributed .Client ):
228147 job_id = "myfake_job_id"
229148 future = dask_client .submit (_some_long_running_task , key = job_id )
230- _wait_for_task_to_start ()
149+ _wait_for_task_to_start (dask_client )
231150
232151 dask_event = distributed .Event (TaskCancelEventName .format (job_id ))
233152 dask_event .set ()
@@ -244,7 +163,8 @@ def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int:
244163
245164 async def _long_running_task_async () -> int :
246165 task_publishers = TaskPublisher (task_owner = task_owner )
247- _notify_task_is_started_and_ready ()
166+ worker = distributed .get_worker ()
167+ _notify_task_is_started_and_ready (worker .client )
248168 current_task = asyncio .current_task ()
249169 assert current_task
250170 async with monitor_task_abortion (
@@ -270,7 +190,7 @@ def test_monitor_task_abortion(
270190 future = dask_client .submit (
271191 _some_long_running_task_with_monitoring , task_owner = task_owner , key = job_id
272192 )
273- _wait_for_task_to_start ()
193+ _wait_for_task_to_start (dask_client )
274194 # trigger cancellation
275195 dask_event = distributed .Event (TaskCancelEventName .format (job_id ))
276196 dask_event .set ()
0 commit comments