Skip to content

Commit 77ac520

Browse files
committed
add test for exception propagation from celery
1 parent 5824f2b commit 77ac520

File tree

2 files changed

+59
-2
lines changed

2 files changed

+59
-2
lines changed

services/api-server/src/simcore_service_api_server/exceptions/backend_errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class BaseBackEndError(ApiServerBaseError):
88
"""status_code: the default return status which will be returned to the client calling the
99
api-server (in case this exception is raised)"""
1010

11+
msg_template = "The api-server encountered an error when contacting the backend"
1112
status_code = status.HTTP_502_BAD_GATEWAY
1213

1314
@classmethod

services/api-server/tests/unit/celery/test_functions.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88
from celery_library.types import register_pydantic_types
99
from faker import Faker
1010
from fastapi import FastAPI, status
11-
from httpx import AsyncClient, BasicAuth
11+
from httpx import AsyncClient, BasicAuth, HTTPStatusError
1212
from models_library.api_schemas_long_running_tasks.tasks import (
1313
TaskGet,
1414
TaskResult,
1515
TaskStatus,
1616
)
17+
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
1718
from models_library.functions import (
1819
FunctionClass,
1920
FunctionID,
@@ -24,17 +25,22 @@
2425
RegisteredProjectFunctionJob,
2526
)
2627
from models_library.projects import ProjectID
27-
from servicelib.celery.models import TaskID
28+
from servicelib.celery.models import TaskFilter, TaskID, TaskMetadata
2829
from servicelib.common_headers import (
2930
X_SIMCORE_PARENT_NODE_ID,
3031
X_SIMCORE_PARENT_PROJECT_UUID,
3132
)
3233
from simcore_service_api_server._meta import API_VTAG
3334
from simcore_service_api_server.api.dependencies.authentication import Identity
35+
from simcore_service_api_server.api.dependencies.celery import (
36+
ASYNC_JOB_CLIENT_NAME,
37+
get_task_manager,
38+
)
3439
from simcore_service_api_server.api.routes.functions_routes import get_function
3540
from simcore_service_api_server.celery._worker_tasks._functions_tasks import (
3641
run_function as run_function_task,
3742
)
43+
from simcore_service_api_server.exceptions.backend_errors import BaseBackEndError
3844
from simcore_service_api_server.models.api_resources import JobLinks
3945
from simcore_service_api_server.models.schemas.jobs import (
4046
JobPricingSpecification,
@@ -151,3 +157,53 @@ async def test_with_fake_run_function(
151157
# Poll until task completion and get result
152158
result = await poll_task_until_done(client, auth, task.task_id)
153159
RegisteredProjectFunctionJob.model_validate(result.result)
160+
161+
162+
def _register_exception_task(exception: Exception) -> Callable[[Celery], None]:
163+
164+
async def exception_task(
165+
task: Task,
166+
task_id: TaskID,
167+
):
168+
raise exception
169+
170+
def _(celery_app: Celery) -> None:
171+
register_task(celery_app, exception_task)
172+
173+
return _
174+
175+
176+
@pytest.mark.parametrize(
177+
"register_celery_tasks",
178+
[
179+
_register_exception_task(ValueError("Test error")),
180+
_register_exception_task(Exception("Test error")),
181+
_register_exception_task(BaseBackEndError()),
182+
],
183+
)
184+
@pytest.mark.parametrize("add_worker_tasks", [False])
185+
async def test_celery_error_propagation(
186+
app: FastAPI,
187+
client: AsyncClient,
188+
auth: BasicAuth,
189+
with_api_server_celery_worker: TestWorkController,
190+
):
191+
192+
user_identity = Identity(
193+
user_id=_faker.pyint(), product_name=_faker.word(), email=_faker.email()
194+
)
195+
job_filter = AsyncJobFilter(
196+
user_id=user_identity.user_id,
197+
product_name=user_identity.product_name,
198+
client_name=ASYNC_JOB_CLIENT_NAME,
199+
)
200+
task_manager = get_task_manager(app=app)
201+
task_uuid = await task_manager.submit_task(
202+
task_metadata=TaskMetadata(name="exception_task"),
203+
task_filter=TaskFilter.model_validate(job_filter.model_dump()),
204+
)
205+
206+
with pytest.raises(HTTPStatusError) as exc_info:
207+
await poll_task_until_done(client, auth, f"{task_uuid}")
208+
209+
assert exc_info.value.response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR

0 commit comments

Comments
 (0)