Skip to content

Commit 91c4651

Browse files
author
Andrei Neagu
committed
avoid CI fro getting stuck
1 parent d07cfda commit 91c4651

File tree

3 files changed

+59
-40
lines changed

3 files changed

+59
-40
lines changed

packages/models-library/src/models_library/api_schemas_long_running_tasks/base.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,14 @@ async def update(
5050
_logger.debug("Progress update: %s", f"{self}")
5151

5252
if self._update_callback is not None:
53-
await self._update_callback(self)
54-
else:
55-
_logger.warning(
56-
"No update callback set for TaskProgress %s, progress will not be propagated",
57-
self.task_id,
58-
)
53+
try:
54+
await self._update_callback(self)
55+
except Exception as exc: # pylint: disable=broad-exception-caught
56+
_logger.warning(
57+
"Error while calling progress update callback: %s",
58+
exc,
59+
stack_info=True,
60+
)
5961

6062
@classmethod
6163
def create(cls, task_id: TaskId | None = None) -> "TaskProgress":

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@
1111
from models_library.api_schemas_long_running_tasks.base import TaskProgress
1212
from pydantic import PositiveFloat
1313
from settings_library.redis import RedisDatabase, RedisSettings
14+
from tenacity import (
15+
AsyncRetrying,
16+
TryAgain,
17+
retry_if_exception_type,
18+
stop_after_delay,
19+
wait_fixed,
20+
)
1421

1522
from ..background_task import create_periodic_task
1623
from ..redis import RedisClientSDK, exclusive
@@ -163,12 +170,22 @@ async def teardown(self) -> None:
163170
tracked_task.task_id, tracked_task.task_context, reraise_errors=False
164171
)
165172

173+
for task in self._created_tasks.values():
174+
_logger.warning(
175+
"Task %s was not completed before shutdown, cancelling it",
176+
task.get_name(),
177+
)
178+
await cancel_wait_task(task)
179+
166180
if self._stale_tasks_monitor_task:
167181
await cancel_wait_task(self._stale_tasks_monitor_task)
168182

169183
if self._cancelled_tasks_removal_task:
170184
await cancel_wait_task(self._cancelled_tasks_removal_task)
171185

186+
if self._status_update_worker_task:
187+
await cancel_wait_task(self._status_update_worker_task)
188+
172189
if self.redis_client_sdk is not None:
173190
await self.redis_client_sdk.shutdown()
174191

@@ -281,27 +298,6 @@ async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBa
281298
if task.task_context == with_task_context
282299
]
283300

284-
async def _add_task(
285-
self,
286-
task: asyncio.Task,
287-
task_progress: TaskProgress,
288-
task_context: TaskContext,
289-
task_id: TaskId,
290-
*,
291-
fire_and_forget: bool,
292-
) -> TaskData:
293-
294-
task_data = TaskData(
295-
task_id=task_id,
296-
task_progress=task_progress,
297-
task_context=task_context,
298-
fire_and_forget=fire_and_forget,
299-
)
300-
await self._tasks_data.set_task_data(task_id, task_data)
301-
self._created_tasks[task_id] = task
302-
303-
return task_data
304-
305301
async def _get_tracked_task(
306302
self, task_id: TaskId, with_task_context: TaskContext
307303
) -> TaskData:
@@ -393,7 +389,18 @@ async def remove_task(
393389
await self._tasks_data.delete_task_data(task_id)
394390
self._created_tasks.pop(tracked_task.task_id, None)
395391

396-
# TODO: wait for removal to becompleted here
392+
# wait for task to be completed
393+
async for attempt in AsyncRetrying(
394+
wait=wait_fixed(0.1),
395+
stop=stop_after_delay(10),
396+
retry=retry_if_exception_type(TryAgain),
397+
):
398+
with attempt:
399+
try:
400+
await self._get_tracked_task(task_id, with_task_context)
401+
raise TryAgain
402+
except TaskNotFoundError:
403+
pass
397404

398405
def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
399406
unique_part = "unique" if is_unique else f"{uuid4()}"
@@ -405,9 +412,18 @@ async def _update_progress(
405412
task_context: TaskContext,
406413
task_progress: TaskProgress,
407414
) -> None:
408-
tracked_data = await self._get_tracked_task(task_id, task_context)
409-
tracked_data.task_progress = task_progress
410-
await self._tasks_data.set_task_data(task_id=task_id, value=tracked_data)
415+
# NOTE: avoids errors while updating progress, since the task could have been
416+
# cancelled and it's data removed
417+
try:
418+
tracked_data = await self._get_tracked_task(task_id, task_context)
419+
tracked_data.task_progress = task_progress
420+
await self._tasks_data.set_task_data(task_id=task_id, value=tracked_data)
421+
except TaskNotFoundError:
422+
_logger.debug(
423+
"Task '%s' not found while updating progress %s",
424+
task_id,
425+
task_progress,
426+
)
411427

412428
async def start_task(
413429
self,
@@ -447,26 +463,25 @@ async def start_task(
447463
functools.partial(self._update_progress, task_id, context_to_use)
448464
)
449465

450-
# bind the task with progress 0 and 1
451-
async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
466+
async def _task_with_progress(progress: TaskProgress, handler: TaskProtocol):
467+
# bind the task with progress 0 and 1
452468
await progress.update(message="starting", percent=0)
453469
try:
454470
return await handler(progress, **task_kwargs)
455471
finally:
456472
await progress.update(message="finished", percent=1)
457473

458-
async_task = asyncio.create_task(
459-
_progress_task(task_progress, task), name=task_name
474+
self._created_tasks[task_id] = asyncio.create_task(
475+
_task_with_progress(task_progress, task), name=task_name
460476
)
461477

462-
tracked_task = await self._add_task(
463-
task=async_task,
478+
tracked_task = TaskData(
479+
task_id=task_id,
464480
task_progress=task_progress,
465481
task_context=context_to_use,
466482
fire_and_forget=fire_and_forget,
467-
task_id=task_id,
468483
)
469-
484+
await self._tasks_data.set_task_data(task_id, tracked_task)
470485
return tracked_task.task_id
471486

472487

packages/service-library/tests/long_running_tasks/test_long_running_tasks_task.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import asyncio
88
import urllib.parse
99
from collections.abc import AsyncIterator, Awaitable, Callable
10+
from contextlib import suppress
1011
from datetime import datetime, timedelta
1112
from typing import Any, Final
1213

@@ -96,7 +97,8 @@ async def _(redis_settings: RedisSettings) -> TasksManager:
9697
yield _
9798

9899
for manager in managers:
99-
await manager.teardown()
100+
with suppress(Exception): # avoids tests form hanging on test teardown
101+
await manager.teardown()
100102

101103

102104
@pytest.fixture

0 commit comments

Comments
 (0)