|
8 | 8 | from celery_library.types import register_pydantic_types |
9 | 9 | from faker import Faker |
10 | 10 | from fastapi import FastAPI, status |
11 | | -from httpx import AsyncClient, BasicAuth |
| 11 | +from httpx import AsyncClient, BasicAuth, HTTPStatusError |
12 | 12 | from models_library.api_schemas_long_running_tasks.tasks import ( |
13 | 13 | TaskGet, |
14 | 14 | TaskResult, |
15 | 15 | TaskStatus, |
16 | 16 | ) |
| 17 | +from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter |
17 | 18 | from models_library.functions import ( |
18 | 19 | FunctionClass, |
19 | 20 | FunctionID, |
|
24 | 25 | RegisteredProjectFunctionJob, |
25 | 26 | ) |
26 | 27 | from models_library.projects import ProjectID |
27 | | -from servicelib.celery.models import TaskID |
| 28 | +from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata |
28 | 29 | from servicelib.common_headers import ( |
29 | 30 | X_SIMCORE_PARENT_NODE_ID, |
30 | 31 | X_SIMCORE_PARENT_PROJECT_UUID, |
31 | 32 | ) |
32 | 33 | from simcore_service_api_server._meta import API_VTAG |
33 | 34 | from simcore_service_api_server.api.dependencies.authentication import Identity |
| 35 | +from simcore_service_api_server.api.dependencies.celery import ( |
| 36 | + ASYNC_JOB_CLIENT_NAME, |
| 37 | + get_task_manager, |
| 38 | +) |
34 | 39 | from simcore_service_api_server.api.routes.functions_routes import get_function |
35 | 40 | from simcore_service_api_server.celery._worker_tasks._functions_tasks import ( |
36 | 41 | run_function as run_function_task, |
37 | 42 | ) |
| 43 | +from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError |
38 | 44 | from simcore_service_api_server.models.api_resources import JobLinks |
39 | 45 | from simcore_service_api_server.models.schemas.jobs import ( |
40 | 46 | JobPricingSpecification, |
@@ -151,3 +157,53 @@ async def test_with_fake_run_function( |
151 | 157 | # Poll until task completion and get result |
152 | 158 | result = await poll_task_until_done(client, auth, task.task_id) |
153 | 159 | RegisteredProjectFunctionJob.model_validate(result.result) |
| 160 | + |
| 161 | + |
| 162 | +def _register_exception_task(exception: Exception) -> Callable[[Celery], None]: |
| 163 | + |
| 164 | + async def exception_task( |
| 165 | + task: Task, |
| 166 | + task_id: TaskID, |
| 167 | + ): |
| 168 | + raise exception |
| 169 | + |
| 170 | + def _(celery_app: Celery) -> None: |
| 171 | + register_task(celery_app, exception_task) |
| 172 | + |
| 173 | + return _ |
| 174 | + |
| 175 | + |
| 176 | +@pytest.mark.parametrize( |
| 177 | + "register_celery_tasks", |
| 178 | + [ |
| 179 | + _register_exception_task(ValueError("Test error")), |
| 180 | + _register_exception_task(Exception("Test error")), |
| 181 | + _register_exception_task(BaseBackEndError()), |
| 182 | + ], |
| 183 | +) |
| 184 | +@pytest.mark.parametrize("add_worker_tasks", [False]) |
| 185 | +async def test_celery_error_propagation( |
| 186 | + app: FastAPI, |
| 187 | + client: AsyncClient, |
| 188 | + auth: BasicAuth, |
| 189 | + with_api_server_celery_worker: TestWorkController, |
| 190 | +): |
| 191 | + |
| 192 | + user_identity = Identity( |
| 193 | + user_id=_faker.pyint(), product_name=_faker.word(), email=_faker.email() |
| 194 | + ) |
| 195 | + job_filter = AsyncJobFilter( |
| 196 | + user_id=user_identity.user_id, |
| 197 | + product_name=user_identity.product_name, |
| 198 | + client_name=ASYNC_JOB_CLIENT_NAME, |
| 199 | + ) |
| 200 | + task_manager = get_task_manager(app=app) |
| 201 | + task_uuid = await task_manager.submit_task( |
| 202 | + task_metadata=TaskMetadata(name="exception_task"), |
| 203 | + task_filter=TaskFilter.model_validate(job_filter.model_dump()), |
| 204 | + ) |
| 205 | + |
| 206 | + with pytest.raises(HTTPStatusError) as exc_info: |
| 207 | + await poll_task_until_done(client, auth, f"{task_uuid}") |
| 208 | + |
| 209 | + assert exc_info.value.response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR |
0 commit comments