|
4 | 4 | from uuid import uuid4 |
5 | 5 |
|
6 | 6 | from celery import Celery # type: ignore[import-untyped] |
| 7 | +from celery.exceptions import CeleryError # type: ignore[import-untyped] |
7 | 8 | from common_library.async_tools import make_async |
8 | 9 | from models_library.progress_bar import ProgressReport |
9 | 10 | from servicelib.celery.models import ( |
|
21 | 22 | from servicelib.logging_utils import log_context |
22 | 23 | from settings_library.celery import CelerySettings |
23 | 24 |
|
24 | | -from .errors import TaskNotFoundError |
| 25 | +from .errors import TaskNotFoundError, TaskSubmissionError |
25 | 26 |
|
26 | 27 | _logger = logging.getLogger(__name__) |
27 | 28 |
|
@@ -50,21 +51,38 @@ async def submit_task( |
50 | 51 | ): |
51 | 52 | task_uuid = uuid4() |
52 | 53 | task_id = task_filter.create_task_id(task_uuid=task_uuid) |
53 | | - self._celery_app.send_task( |
54 | | - task_metadata.name, |
55 | | - task_id=task_id, |
56 | | - kwargs={"task_id": task_id} | task_params, |
57 | | - queue=task_metadata.queue.value, |
58 | | - ) |
59 | 54 |
|
60 | 55 | expiry = ( |
61 | 56 | self._celery_settings.CELERY_EPHEMERAL_RESULT_EXPIRES |
62 | 57 | if task_metadata.ephemeral |
63 | 58 | else self._celery_settings.CELERY_RESULT_EXPIRES |
64 | 59 | ) |
65 | | - await self._task_info_store.create_task( |
66 | | - task_id, task_metadata, expiry=expiry |
67 | | - ) |
| 60 | + |
| 61 | + try: |
| 62 | + await self._task_info_store.create_task( |
| 63 | + task_id, task_metadata, expiry=expiry |
| 64 | + ) |
| 65 | + self._celery_app.send_task( |
| 66 | + task_metadata.name, |
| 67 | + task_id=task_id, |
| 68 | + kwargs={"task_id": task_id} | task_params, |
| 69 | + queue=task_metadata.queue.value, |
| 70 | + ) |
| 71 | + except CeleryError as exc: |
| 72 | + try: |
| 73 | + await self._task_info_store.remove_task(task_id) |
| 74 | + except CeleryError: |
| 75 | + _logger.warning( |
| 76 | + "Unable to cleanup task '%s' during error handling", |
| 77 | + task_id, |
| 78 | + exc_info=True, |
| 79 | + ) |
| 80 | + raise TaskSubmissionError( |
| 81 | + task_name=task_metadata.name, |
| 82 | + task_id=task_id, |
| 83 | + task_params=task_params, |
| 84 | + ) from exc |
| 85 | + |
68 | 86 | return task_uuid |
69 | 87 |
|
70 | 88 | async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None: |
|
0 commit comments