@@ -6,25 +6,28 @@

 import asyncio
 import concurrent.futures
+import logging
 import time
-from collections.abc import AsyncIterator, Callable, Coroutine
 from typing import Any

 import distributed
 import pytest
+from common_library.async_tools import maybe_await
 from dask_task_models_library.container_tasks.errors import TaskCancelledError
 from dask_task_models_library.container_tasks.events import TaskProgressEvent
 from dask_task_models_library.container_tasks.io import TaskCancelEventName
 from dask_task_models_library.container_tasks.protocol import TaskOwner
+from pytest_simcore.helpers.logging_tools import log_context
 from simcore_service_dask_sidecar.utils.dask import (
     _DEFAULT_MAX_RESOURCES,
     TaskPublisher,
     get_current_task_resources,
     is_current_task_aborted,
     monitor_task_abortion,
     publish_event,
 )
-from tenacity.asyncio import AsyncRetrying
+from tenacity import Retrying
 from tenacity.retry import retry_if_exception_type
 from tenacity.stop import stop_after_delay
 from tenacity.wait import wait_fixed
@@ -37,128 +40,85 @@
 ]


-def test_publish_event(
-    dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
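+# NOTE: this fixture runs each test twice, once against the sync dask client
+# and once against the async one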
+@pytest.fixture(params=["sync-dask-client", "async-dask-client"])
+def dask_client_multi(
+    request: pytest.FixtureRequest,
+    dask_client: distributed.Client,
+    async_dask_client: distributed.Client,
+) -> distributed.Client:
+    if request.param == "sync-dask-client":
+        return dask_client
+    return async_dask_client
+
+
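+# NOTE: Client.subscribe_topic accepts either a plain callable or a coroutine
+# function as handler, hence the Mock/AsyncMock parametrization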
+@pytest.mark.parametrize(
+    "handler", [mock.Mock(), mock.AsyncMock()], ids=["sync-handler", "async-handler"]
+)
+async def test_publish_event(
+    dask_client_multi: distributed.Client,
+    job_id: str,
+    task_owner: TaskOwner,
+    handler: mock.Mock | mock.AsyncMock,
 ):
-    dask_pub = distributed.Pub("some_topic", client=dask_client)
-    dask_sub = distributed.Sub("some_topic", client=dask_client)
     event_to_publish = TaskProgressEvent(
         job_id=job_id,
         msg="the log",
         progress=1,
         task_owner=task_owner,
     )
-    publish_event(dask_pub=dask_pub, event=event_to_publish)
-
-    # NOTE: this test runs a sync dask client,
-    # and the CI seems to have sometimes difficulties having this run in a reasonable time
-    # hence the long timeout
-    message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
-    assert message is not None
-    assert isinstance(message, str)
-    received_task_log_event = TaskProgressEvent.model_validate_json(message)
-    assert received_task_log_event == event_to_publish
-
-
-async def test_publish_event_async(
-    async_dask_client: distributed.Client, job_id: str, task_owner: TaskOwner
-):
-    dask_pub = distributed.Pub("some_topic", client=async_dask_client)
-    dask_sub = distributed.Sub("some_topic", client=async_dask_client)
-    event_to_publish = TaskProgressEvent(
-        job_id=job_id, msg="the log", progress=2, task_owner=task_owner
-    )
-    publish_event(dask_pub=dask_pub, event=event_to_publish)
-
-    # NOTE: this test runs a sync dask client,
-    # and the CI seems to have sometimes difficulties having this run in a reasonable time
-    # hence the long timeout
-    message = dask_sub.get(timeout=DASK_TESTING_TIMEOUT_S)
-    assert isinstance(message, Coroutine)
-    message = await message
-    assert message is not None
-    received_task_log_event = TaskProgressEvent.model_validate_json(message)
-    assert received_task_log_event == event_to_publish

+    # NOTE: only 1 handler per topic is allowed at a time
+    dask_client_multi.subscribe_topic(TaskProgressEvent.topic_name(), handler)

-@pytest.fixture
-async def asyncio_task() -> AsyncIterator[Callable[[Coroutine], asyncio.Task]]:
-    created_tasks = []
+    def _worker_task() -> int:
+        with log_context(logging.INFO, "_worker_task"):

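+            # publish the event from within a running event loop, as code
+            # executing on the worker would (run below via asyncio.run)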
-    def _creator(coro: Coroutine) -> asyncio.Task:
-        task = asyncio.create_task(coro, name="pytest_asyncio_task")
-        created_tasks.append(task)
-        return task
+            async def _() -> int:
+                with log_context(logging.INFO, "_worker_task_async"):
+                    publish_event(event_to_publish)
+                    return 2

-    yield _creator
-    for task in created_tasks:
-        task.cancel()
+            return asyncio.run(_())

-    await asyncio.gather(*created_tasks, return_exceptions=True)
+    future = dask_client_multi.submit(_worker_task)
+    assert await maybe_await(future.result(timeout=DASK_TESTING_TIMEOUT_S)) == 2

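+    # poll until the event shows up on the scheduler; maybe_await lets the
+    # same code work with both the sync and the async client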
-
-async def test_publish_event_async_using_task(
-    async_dask_client: distributed.Client,
-    asyncio_task: Callable[[Coroutine], asyncio.Task],
-    job_id: str,
-    task_owner: TaskOwner,
-):
-    dask_pub = distributed.Pub("some_topic", client=async_dask_client)
-    dask_sub = distributed.Sub("some_topic", client=async_dask_client)
-    NUMBER_OF_MESSAGES = 1000
-    received_messages = []
-
-    async def _dask_sub_consumer_task(sub: distributed.Sub) -> None:
-        print("--> starting consumer task")
-        async for dask_event in sub:
-            print(f"received {dask_event}")
-            received_messages.append(dask_event)
-        print("<-- finished consumer task")
-
-    consumer_task = asyncio_task(_dask_sub_consumer_task(dask_sub))
-    assert consumer_task
-
-    async def _dask_publisher_task(pub: distributed.Pub) -> None:
-        print("--> starting publisher task")
-        for _ in range(NUMBER_OF_MESSAGES):
-            event_to_publish = TaskProgressEvent(
-                job_id=job_id,
-                progress=0.5,
-                task_owner=task_owner,
-            )
-            publish_event(dask_pub=pub, event=event_to_publish)
-        print("<-- finished publisher task")
-
-    publisher_task = asyncio_task(_dask_publisher_task(dask_pub))
-    assert publisher_task
-
-    async for attempt in AsyncRetrying(
-        retry=retry_if_exception_type(AssertionError),
-        stop=stop_after_delay(DASK_TESTING_TIMEOUT_S),
-        wait=wait_fixed(0.01),
+    for attempt in Retrying(
+        wait=wait_fixed(0.2),
+        stop=stop_after_delay(15),
         reraise=True,
+        retry=retry_if_exception_type(AssertionError),
     ):
         with attempt:
-            print(
-                f"checking number of received messages...currently {len(received_messages)}"
+            events = await maybe_await(
+                dask_client_multi.get_events(TaskProgressEvent.topic_name())
             )
-            assert len(received_messages) == NUMBER_OF_MESSAGES
-            print("all expected messages received")
+            assert events is not None, "No events received"
+            assert isinstance(events, tuple)

+            handler.assert_called_with(events[-1])

-def _wait_for_task_to_start() -> None:
-    start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
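+    # get_events returns a tuple of (timestamp, message) pairs, so the
+    # serialized event payload sits at index 1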
+    assert isinstance(events, tuple)
+    assert len(events) == 1
+    assert isinstance(events[0], tuple)
+    received_task_log_event = TaskProgressEvent.model_validate_json(events[0][1])
+    assert received_task_log_event == event_to_publish
+
+
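+# NOTE: a distributed.Event is coordinated through the scheduler, which lets
+# the test synchronize with code running inside a worker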
+def _wait_for_task_to_start(dask_client: distributed.Client) -> None:
+    start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
     start_event.wait(timeout=DASK_TESTING_TIMEOUT_S)


-def _notify_task_is_started_and_ready() -> None:
-    start_event = distributed.Event(DASK_TASK_STARTED_EVENT)
+def _notify_task_is_started_and_ready(dask_client: distributed.Client) -> None:
+    start_event = distributed.Event(DASK_TASK_STARTED_EVENT, dask_client)
     start_event.set()


 def _some_long_running_task() -> int:
     assert is_current_task_aborted() is False
-    _notify_task_is_started_and_ready()
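+    # this runs on a worker: use the worker's own client so the Event is
+    # registered with the right scheduler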
+    dask_client = distributed.get_worker().client
+    _notify_task_is_started_and_ready(dask_client)

     for i in range(300):
         print("running iteration", i)
@@ -176,7 +136,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
     not work in distributed mode where an Event is necessary."""
     # NOTE: this works because the cluster is in the same machine
     future = dask_client.submit(_some_long_running_task)
-    _wait_for_task_to_start()
+    _wait_for_task_to_start(dask_client)
     future.cancel()
     assert future.cancelled()
     with pytest.raises(concurrent.futures.CancelledError):
@@ -186,7 +146,7 @@ def test_task_is_aborted(dask_client: distributed.Client):
 def test_task_is_aborted_using_event(dask_client: distributed.Client):
     job_id = "myfake_job_id"
     future = dask_client.submit(_some_long_running_task, key=job_id)
-    _wait_for_task_to_start()
+    _wait_for_task_to_start(dask_client)

     dask_event = distributed.Event(TaskCancelEventName.format(job_id))
     dask_event.set()
@@ -203,7 +163,8 @@ def _some_long_running_task_with_monitoring(task_owner: TaskOwner) -> int:

     async def _long_running_task_async() -> int:
         task_publishers = TaskPublisher(task_owner=task_owner)
-        _notify_task_is_started_and_ready()
+        worker = distributed.get_worker()
+        _notify_task_is_started_and_ready(worker.client)
         current_task = asyncio.current_task()
         assert current_task
         async with monitor_task_abortion(
@@ -229,7 +190,7 @@ def test_monitor_task_abortion(
     future = dask_client.submit(
         _some_long_running_task_with_monitoring, task_owner=task_owner, key=job_id
     )
-    _wait_for_task_to_start()
+    _wait_for_task_to_start(dask_client)
     # trigger cancellation
     dask_event = distributed.Event(TaskCancelEventName.format(job_id))
     dask_event.set()