Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ async def get_task_result(request: web.Request) -> web.Response | Any:
long_running_manager.tasks_manager,
long_running_manager.get_task_context(request),
path_params.task_id,
is_fasapi=False,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,13 @@
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_exponential

from ...long_running_tasks.errors import GenericClientError
from ...long_running_tasks.models import ClientConfiguration, TaskId, TaskStatus
from ...long_running_tasks.errors import GenericClientError, TaskClientResultError
from ...long_running_tasks.models import (
ClientConfiguration,
TaskId,
TaskResult,
TaskStatus,
)

_DEFAULT_HTTP_REQUESTS_TIMEOUT: Final[PositiveFloat] = 15

Expand Down Expand Up @@ -168,7 +173,10 @@ async def get_task_result(
body=result.text,
)

return result.json()
task_result = TaskResult.model_validate(result.json())
if task_result.error is not None:
raise TaskClientResultError(message=task_result.error)
return task_result.result

@retry_on_http_errors
async def cancel_and_delete_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,10 @@ async def get_task_result(
) -> TaskResult | Any:
assert request # nosec
return await http_endpoint_responses.get_task_result(
long_running_manager.tasks_manager, task_context=None, task_id=task_id
long_running_manager.tasks_manager,
task_context=None,
task_id=task_id,
is_fasapi=True,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,8 @@ class GenericClientError(BaseLongRunningError):
msg_template: str = (
"Unexpected error while '{action}' for '{task_id}': status={status} body={body}"
)


class TaskClientResultError(BaseLongRunningError):
code: str = "long_running_task.client.task_raised_error"
msg_template: str = "{message}"
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import logging
import traceback
from typing import Any

from .errors import TaskNotCompletedError, TaskNotFoundError
from .models import TaskBase, TaskId, TaskStatus
from .task import TaskContext, TasksManager, TrackedTask

_logger = logging.getLogger(__name__)


def list_tasks(
tasks_manager: TasksManager, task_context: TaskContext | None
Expand All @@ -22,17 +27,37 @@ def get_task_status(


async def get_task_result(
tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId
tasks_manager: TasksManager,
task_context: TaskContext | None,
task_id: TaskId,
*,
is_fasapi: bool,
) -> Any:
try:
return tasks_manager.get_task_result(
task_id=task_id, with_task_context=task_context
if is_fasapi:
task_result = tasks_manager.get_task_result_old(
task_id, with_task_context=task_context
)
else:
task_result = tasks_manager.get_task_result(
task_id, with_task_context=task_context
)
await tasks_manager.remove_task(
task_id, with_task_context=task_context, reraise_errors=False
)
finally:
# the task is always removed even if an error occurs
return task_result
except (TaskNotFoundError, TaskNotCompletedError):
raise
except Exception as exc:
# the task raised an exception
formatted_traceback = "".join(traceback.format_exception(exc))
_logger.info("Task '%s' raised an exception: %s", task_id, formatted_traceback)

# the task shall be removed in this case
await tasks_manager.remove_task(
task_id, with_task_context=task_context, reraise_errors=False
)
raise


async def remove_task(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from uuid import uuid4

from models_library.api_schemas_long_running_tasks.base import TaskProgress
from models_library.api_schemas_long_running_tasks.tasks import TaskResult
from pydantic import PositiveFloat
from servicelib.async_utils import cancel_wait_task
from servicelib.background_task import create_periodic_task
Expand Down Expand Up @@ -259,6 +260,38 @@ def get_task_result(
# the task was cancelled
raise TaskCancelledError(task_id=task_id) from exc

def get_task_result_old(
self, task_id: TaskId, with_task_context: TaskContext | None
) -> TaskResult:
"""
returns: the result of the task

raises TaskNotFoundError if the task cannot be found
"""
tracked_task = self._get_tracked_task(task_id, with_task_context)

if not tracked_task.task.done():
raise TaskNotCompletedError(task_id=task_id)

error: TaskExceptionError | TaskCancelledError
try:
exception = tracked_task.task.exception()
if exception is not None:
formatted_traceback = "\n".join(
traceback.format_tb(exception.__traceback__)
)
error = TaskExceptionError(
task_id=task_id, exception=exception, traceback=formatted_traceback
)
_logger.warning("Task %s finished with error: %s", task_id, f"{error}")
return TaskResult(result=None, error=f"{error}")
except asyncio.CancelledError:
error = TaskCancelledError(task_id=task_id)
_logger.warning("Task %s was cancelled", task_id)
return TaskResult(result=None, error=f"{error}")

return TaskResult(result=tracked_task.task.result(), error=None)

async def cancel_task(
self, task_id: TaskId, with_task_context: TaskContext | None
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from servicelib.fastapi.long_running_tasks.server import get_long_running_manager
from servicelib.fastapi.long_running_tasks.server import setup as setup_server
from servicelib.long_running_tasks.errors import (
TaskClientResultError,
TaskClientTimeoutError,
)
from servicelib.long_running_tasks.models import (
Expand Down Expand Up @@ -148,14 +149,16 @@ async def test_task_result_task_result_is_an_error(

url = TypeAdapter(AnyHttpUrl).validate_python("http://backgroud.testserver.io/")
client = Client(app=bg_task_app, async_client=async_client, base_url=url)
with pytest.raises(RuntimeError, match="I am failing as requested"):
with pytest.raises(TaskClientResultError) as exec_info:
async with periodic_task_result(
client,
task_id,
task_timeout=10,
status_poll_interval=TASK_SLEEP_INTERVAL / 3,
):
pass
assert f"{exec_info.value}".startswith(f"Task {task_id} finished with exception:")
assert "I am failing as requested" in f"{exec_info.value}"
await _assert_task_removed(async_client, task_id, router_prefix)


Expand Down
Loading