1111from models_library .api_schemas_long_running_tasks .base import TaskProgress
1212from pydantic import PositiveFloat
1313from settings_library .redis import RedisDatabase , RedisSettings
14+ from tenacity import (
15+ AsyncRetrying ,
16+ TryAgain ,
17+ retry_if_exception_type ,
18+ stop_after_delay ,
19+ wait_fixed ,
20+ )
1421
1522from ..background_task import create_periodic_task
1623from ..redis import RedisClientSDK , exclusive
@@ -163,12 +170,22 @@ async def teardown(self) -> None:
163170 tracked_task .task_id , tracked_task .task_context , reraise_errors = False
164171 )
165172
173+ for task in self ._created_tasks .values ():
174+ _logger .warning (
175+ "Task %s was not completed before shutdown, cancelling it" ,
176+ task .get_name (),
177+ )
178+ await cancel_wait_task (task )
179+
166180 if self ._stale_tasks_monitor_task :
167181 await cancel_wait_task (self ._stale_tasks_monitor_task )
168182
169183 if self ._cancelled_tasks_removal_task :
170184 await cancel_wait_task (self ._cancelled_tasks_removal_task )
171185
186+ if self ._status_update_worker_task :
187+ await cancel_wait_task (self ._status_update_worker_task )
188+
172189 if self .redis_client_sdk is not None :
173190 await self .redis_client_sdk .shutdown ()
174191
@@ -281,27 +298,6 @@ async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBa
281298 if task .task_context == with_task_context
282299 ]
283300
284- async def _add_task (
285- self ,
286- task : asyncio .Task ,
287- task_progress : TaskProgress ,
288- task_context : TaskContext ,
289- task_id : TaskId ,
290- * ,
291- fire_and_forget : bool ,
292- ) -> TaskData :
293-
294- task_data = TaskData (
295- task_id = task_id ,
296- task_progress = task_progress ,
297- task_context = task_context ,
298- fire_and_forget = fire_and_forget ,
299- )
300- await self ._tasks_data .set_task_data (task_id , task_data )
301- self ._created_tasks [task_id ] = task
302-
303- return task_data
304-
305301 async def _get_tracked_task (
306302 self , task_id : TaskId , with_task_context : TaskContext
307303 ) -> TaskData :
@@ -393,7 +389,18 @@ async def remove_task(
393389 await self ._tasks_data .delete_task_data (task_id )
394390 self ._created_tasks .pop (tracked_task .task_id , None )
395391
396- # TODO: wait for removal to becompleted here
392+ # wait for task to be completed
393+ async for attempt in AsyncRetrying (
394+ wait = wait_fixed (0.1 ),
395+ stop = stop_after_delay (10 ),
396+ retry = retry_if_exception_type (TryAgain ),
397+ ):
398+ with attempt :
399+ try :
400+ await self ._get_tracked_task (task_id , with_task_context )
401+ raise TryAgain
402+ except TaskNotFoundError :
403+ pass
397404
398405 def _get_task_id (self , task_name : str , * , is_unique : bool ) -> TaskId :
399406 unique_part = "unique" if is_unique else f"{ uuid4 ()} "
@@ -405,9 +412,18 @@ async def _update_progress(
405412 task_context : TaskContext ,
406413 task_progress : TaskProgress ,
407414 ) -> None :
408- tracked_data = await self ._get_tracked_task (task_id , task_context )
409- tracked_data .task_progress = task_progress
410- await self ._tasks_data .set_task_data (task_id = task_id , value = tracked_data )
415+ # NOTE: avoids errors while updating progress, since the task could have been
416+ # cancelled and it's data removed
417+ try :
418+ tracked_data = await self ._get_tracked_task (task_id , task_context )
419+ tracked_data .task_progress = task_progress
420+ await self ._tasks_data .set_task_data (task_id = task_id , value = tracked_data )
421+ except TaskNotFoundError :
422+ _logger .debug (
423+ "Task '%s' not found while updating progress %s" ,
424+ task_id ,
425+ task_progress ,
426+ )
411427
412428 async def start_task (
413429 self ,
@@ -447,26 +463,25 @@ async def start_task(
447463 functools .partial (self ._update_progress , task_id , context_to_use )
448464 )
449465
450- # bind the task with progress 0 and 1
451- async def _progress_task ( progress : TaskProgress , handler : TaskProtocol ):
466+ async def _task_with_progress ( progress : TaskProgress , handler : TaskProtocol ):
467+ # bind the task with progress 0 and 1
452468 await progress .update (message = "starting" , percent = 0 )
453469 try :
454470 return await handler (progress , ** task_kwargs )
455471 finally :
456472 await progress .update (message = "finished" , percent = 1 )
457473
458- async_task = asyncio .create_task (
459- _progress_task (task_progress , task ), name = task_name
474+ self . _created_tasks [ task_id ] = asyncio .create_task (
475+ _task_with_progress (task_progress , task ), name = task_name
460476 )
461477
462- tracked_task = await self . _add_task (
463- task = async_task ,
478+ tracked_task = TaskData (
479+ task_id = task_id ,
464480 task_progress = task_progress ,
465481 task_context = context_to_use ,
466482 fire_and_forget = fire_and_forget ,
467- task_id = task_id ,
468483 )
469-
484+ await self . _tasks_data . set_task_data ( task_id , tracked_task )
470485 return tracked_task .task_id
471486
472487
0 commit comments