|
2 | 2 | # pylint: disable=unused-argument |
3 | 3 |
|
4 | 4 |
|
| 5 | +from collections.abc import Callable |
| 6 | +from typing import Literal |
| 7 | + |
5 | 8 | import pytest |
| 9 | +from celery.exceptions import CeleryError |
6 | 10 | from faker import Faker |
7 | 11 | from fastapi import status |
8 | 12 | from httpx import AsyncClient, BasicAuth |
@@ -37,28 +41,43 @@ def _get_task_manager(app): |
37 | 41 | return mock_task_manager_object |
38 | 42 |
|
39 | 43 |
|
40 | | -@pytest.mark.parametrize( |
41 | | - "expected_status_code", |
42 | | - [status.HTTP_200_OK], |
43 | | -) |
44 | 44 | async def test_list_celery_tasks( |
45 | 45 | mock_task_manager: MockType, |
46 | 46 | client: AsyncClient, |
47 | 47 | auth: BasicAuth, |
48 | | - expected_status_code: int, |
49 | 48 | ): |
50 | 49 |
|
51 | 50 | response = await client.get("/v0/tasks", auth=auth) |
52 | | - assert response.status_code == expected_status_code |
| 51 | + assert response.status_code == status.HTTP_200_OK |
53 | 52 |
|
54 | | - if response.status_code == status.HTTP_200_OK: |
55 | | - result = ApiServerEnvelope[list[TaskGet]].model_validate_json(response.text) |
56 | | - assert len(result.data) > 0 |
57 | | - assert all(isinstance(task, TaskGet) for task in result.data) |
58 | | - task = result.data[0] |
59 | | - assert task.abort_href == f"/v0/tasks/{task.task_id}:cancel" |
60 | | - assert task.result_href == f"/v0/tasks/{task.task_id}/result" |
61 | | - assert task.status_href == f"/v0/tasks/{task.task_id}" |
| 53 | + result = ApiServerEnvelope[list[TaskGet]].model_validate_json(response.text) |
| 54 | + assert len(result.data) > 0 |
| 55 | + assert all(isinstance(task, TaskGet) for task in result.data) |
| 56 | + task = result.data[0] |
| 57 | + assert task.abort_href == f"/v0/tasks/{task.task_id}:cancel" |
| 58 | + assert task.result_href == f"/v0/tasks/{task.task_id}/result" |
| 59 | + assert task.status_href == f"/v0/tasks/{task.task_id}" |
| 60 | + |
| 61 | + |
| 62 | +@pytest.mark.parametrize( |
| 63 | + "method, url, celery_exception, expected_status_code", |
| 64 | + [ |
| 65 | + ("GET", "/v0/tasks", CeleryError(), status.HTTP_500_INTERNAL_SERVER_ERROR), |
| 66 | + ], |
| 67 | +) |
| 68 | +async def test_celery_tasks_error_propagation( |
| 69 | + mock_task_manager_raising_factory: Callable[[Exception], None], |
| 70 | + client: AsyncClient, |
| 71 | + auth: BasicAuth, |
| 72 | + method: Literal["GET", "POST"], |
| 73 | + url: str, |
| 74 | + celery_exception: Exception, |
| 75 | + expected_status_code: int, |
| 76 | +): |
| 77 | + mock_task_manager_raising_factory(celery_exception) |
| 78 | + |
| 79 | + response = await client.request(method=method, url=url, auth=auth) |
| 80 | + assert response.status_code == expected_status_code |
62 | 81 |
|
63 | 82 |
|
64 | 83 | @pytest.mark.parametrize( |
|
0 commit comments