Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import functools
import logging
import warnings
from typing import Any, Awaitable, Callable, Final
from collections.abc import Awaitable, Callable
from typing import Any, Final

from fastapi import FastAPI, status
from httpx import AsyncClient, HTTPError
Expand All @@ -13,13 +14,8 @@
from tenacity.stop import stop_after_attempt
from tenacity.wait import wait_exponential

from ...long_running_tasks._errors import GenericClientError, TaskClientResultError
from ...long_running_tasks._models import (
ClientConfiguration,
TaskId,
TaskResult,
TaskStatus,
)
from ...long_running_tasks._errors import GenericClientError
from ...long_running_tasks._models import ClientConfiguration, TaskId, TaskStatus

DEFAULT_HTTP_REQUESTS_TIMEOUT: Final[PositiveFloat] = 15

Expand Down Expand Up @@ -85,7 +81,7 @@ def log_it(retry_state: RetryCallState) -> None:


def retry_on_http_errors(
request_func: Callable[..., Awaitable[Any]]
request_func: Callable[..., Awaitable[Any]],
) -> Callable[..., Awaitable[Any]]:
"""
Will retry the request on `httpx.HTTPError`.
Expand Down Expand Up @@ -173,10 +169,7 @@ async def get_task_result(
body=result.text,
)

task_result = TaskResult.model_validate(result.json())
if task_result.error is not None:
raise TaskClientResultError(message=task_result.error)
return task_result.result
return result.json()

@retry_on_http_errors
async def cancel_and_delete_task(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import asyncio
from asyncio.log import logger
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Final
from typing import Any, Final

from pydantic import PositiveFloat

Expand Down Expand Up @@ -87,8 +88,7 @@ async def periodic_task_result(
- `status_poll_interval` optional: when waiting for a task to finish,
how frequent should the server be queried

raises: `TaskClientResultError` if the task finished with an error instead of
the expected result
raises: the original expcetion the task raised, if any
raises: `asyncio.TimeoutError` NOTE: the remote task will also be removed
"""

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Annotated, Any

from fastapi import APIRouter, Depends, Query, Request, status
from fastapi import APIRouter, Depends, Request, status

from ...long_running_tasks._errors import TaskNotCompletedError, TaskNotFoundError
from ...long_running_tasks._models import TaskGet, TaskId, TaskResult, TaskStatus
Expand Down Expand Up @@ -60,16 +60,10 @@ async def get_task_result(
request: Request,
task_id: TaskId,
tasks_manager: Annotated[TasksManager, Depends(get_tasks_manager)],
*,
return_exception: Annotated[bool, Query()] = False,
) -> TaskResult | Any:
assert request # nosec
# TODO: refactor this to use same as in https://github.com/ITISFoundation/osparc-simcore/issues/3265
try:
if return_exception:
task_result = tasks_manager.get_task_result(task_id, with_task_context=None)
else:
task_result = tasks_manager.get_task_result_old(task_id=task_id)
task_result = tasks_manager.get_task_result(task_id, with_task_context=None)
await tasks_manager.remove_task(
task_id, with_task_context=None, reraise_errors=False
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import httpx
from fastapi import status
from models_library.api_schemas_long_running_tasks.base import TaskProgress
from models_library.api_schemas_long_running_tasks.tasks import TaskGet, TaskStatus
from models_library.api_schemas_long_running_tasks.tasks import (
TaskGet,
TaskResult,
TaskStatus,
)
from tenacity import (
AsyncRetrying,
TryAgain,
Expand All @@ -23,7 +27,6 @@
from yarl import URL

from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
from ...long_running_tasks._errors import TaskClientResultError
from ...long_running_tasks._models import (
ClientConfiguration,
LRTask,
Expand All @@ -32,7 +35,7 @@
ProgressPercent,
RequestBody,
)
from ...long_running_tasks._task import TaskId, TaskResult
from ...long_running_tasks._task import TaskId
from ...rest_responses import unwrap_envelope_if_required
from ._client import DEFAULT_HTTP_REQUESTS_TIMEOUT, Client, setup
from ._context_manager import periodic_task_result
Expand Down Expand Up @@ -97,7 +100,7 @@ async def _wait_for_completion(

@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
async def _task_result(session: httpx.AsyncClient, result_url: URL) -> Any:
response = await session.get(f"{result_url}", params={"return_exception": True})
response = await session.get(f"{result_url}")
response.raise_for_status()
if response.status_code != status.HTTP_204_NO_CONTENT:
return unwrap_envelope_if_required(response.json())
Expand Down Expand Up @@ -155,7 +158,6 @@ async def long_running_task_request(
"ProgressCallback",
"ProgressMessage",
"ProgressPercent",
"TaskClientResultError",
"TaskId",
"TaskResult",
"periodic_task_result",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
running task. The client will take care of recovering the result from it.
"""

from models_library.api_schemas_long_running_tasks.tasks import TaskResult

from ...long_running_tasks._errors import TaskAlreadyRunningError, TaskCancelledError
from ...long_running_tasks._task import (
TaskId,
TaskProgress,
TaskResult,
TasksManager,
TaskStatus,
start_task,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,50 +4,36 @@
class BaseLongRunningError(OsparcErrorMixin, Exception):
"""base exception for this module"""

code: str = "long_running_task.base_long_running_error" # type: ignore[assignment]


class TaskAlreadyRunningError(BaseLongRunningError):
code: str = "long_running_task.task_already_running"
msg_template: str = "{task_name} must be unique, found: '{managed_task}'"


class TaskNotFoundError(BaseLongRunningError):
code: str = "long_running_task.task_not_found"
msg_template: str = "No task with {task_id} found"


class TaskNotCompletedError(BaseLongRunningError):
code: str = "long_running_task.task_not_completed"
msg_template: str = "Task {task_id} has not finished yet"


class TaskCancelledError(BaseLongRunningError):
code: str = "long_running_task.task_cancelled_error"
msg_template: str = "Task {task_id} was cancelled before completing"


class TaskExceptionError(BaseLongRunningError):
code: str = "long_running_task.task_exception_error"
msg_template: str = (
"Task {task_id} finished with exception: '{exception}'\n{traceback}"
)


class TaskClientTimeoutError(BaseLongRunningError):
code: str = "long_running_task.client.timed_out_waiting_for_response"
msg_template: str = (
"Timed out after {timeout} seconds while awaiting '{task_id}' to complete"
)


class GenericClientError(BaseLongRunningError):
code: str = "long_running_task.client.generic_error"
msg_template: str = (
"Unexpected error while '{action}' for '{task_id}': status={status} body={body}"
)


class TaskClientResultError(BaseLongRunningError):
code: str = "long_running_task.client.task_raised_error"
msg_template: str = "{message}"
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TaskNotCompletedError,
TaskNotFoundError,
)
from ._models import TaskId, TaskName, TaskResult, TaskStatus, TrackedTask
from ._models import TaskId, TaskName, TaskStatus, TrackedTask

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -241,36 +241,6 @@ def get_task_result(
# the task was cancelled
raise TaskCancelledError(task_id=task_id) from exc

def get_task_result_old(self, task_id: TaskId) -> TaskResult:
"""
returns: the result of the task

raises TaskNotFoundError if the task cannot be found
"""
tracked_task = self._get_tracked_task(task_id, {})

if not tracked_task.task.done():
raise TaskNotCompletedError(task_id=task_id)

error: TaskExceptionError | TaskCancelledError
try:
exception = tracked_task.task.exception()
if exception is not None:
formatted_traceback = "\n".join(
traceback.format_tb(exception.__traceback__)
)
error = TaskExceptionError(
task_id=task_id, exception=exception, traceback=formatted_traceback
)
logger.warning("Task %s finished with error: %s", task_id, f"{error}")
return TaskResult(result=None, error=f"{error}")
except asyncio.CancelledError:
error = TaskCancelledError(task_id=task_id)
logger.warning("Task %s was cancelled", task_id)
return TaskResult(result=None, error=f"{error}")

return TaskResult(result=tracked_task.task.result(), error=None)

async def cancel_task(
self, task_id: TaskId, with_task_context: TaskContext | None
) -> None:
Expand Down Expand Up @@ -354,12 +324,12 @@ async def close(self) -> None:


class TaskProtocol(Protocol):
async def __call__(self, progress: TaskProgress, *args: Any, **kwargs: Any) -> Any:
...
async def __call__(
self, progress: TaskProgress, *args: Any, **kwargs: Any
) -> Any: ...

@property
def __name__(self) -> str:
...
def __name__(self) -> str: ...


def start_task(
Expand Down Expand Up @@ -449,7 +419,5 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
"TaskProgress",
"TaskProtocol",
"TaskStatus",
"TaskResult",
"TrackedTask",
"TaskResult",
)
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,9 @@ async def _caller(app: FastAPI, client: AsyncClient, **query_kwargs) -> TaskId:


@pytest.fixture
def wait_for_task() -> Callable[
[FastAPI, AsyncClient, TaskId, TaskContext], Awaitable[None]
]:
def wait_for_task() -> (
Callable[[FastAPI, AsyncClient, TaskId, TaskContext], Awaitable[None]]
):
async def _waiter(
app: FastAPI,
client: AsyncClient,
Expand Down Expand Up @@ -183,9 +183,7 @@ async def test_workflow(
result = await client.get(f"{result_url}")
# NOTE: this is DIFFERENT than with aiohttp where we return the real result
assert result.status_code == status.HTTP_200_OK
task_result = long_running_tasks.server.TaskResult.model_validate(result.json())
assert not task_result.error
assert task_result.result == [f"{x}" for x in range(10)]
assert result.json() == [f"{x}" for x in range(10)]
# getting the result again should raise a 404
result = await client.get(result_url)
assert result.status_code == status.HTTP_404_NOT_FOUND
Expand Down Expand Up @@ -220,19 +218,9 @@ async def test_failing_task_returns_error(
await wait_for_task(app, client, task_id, {})
# get the result
result_url = app.url_path_for("get_task_result", task_id=task_id)
result = await client.get(f"{result_url}")
assert result.status_code == status.HTTP_200_OK
task_result = long_running_tasks.server.TaskResult.model_validate(result.json())

assert not task_result.result
assert task_result.error
assert task_result.error.startswith(f"Task {task_id} finished with exception: ")
assert 'raise RuntimeError("We were asked to fail!!")' in task_result.error
# NOTE: this is not yet happening with fastapi version of long running task
# assert "errors" in task_result.error
# assert len(task_result.error["errors"]) == 1
# assert task_result.error["errors"][0]["code"] == "RuntimeError"
# assert task_result.error["errors"][0]["message"] == "We were asked to fail!!"
with pytest.raises(RuntimeError) as exec_info:
await client.get(f"{result_url}")
assert f"{exec_info.value}" == "We were asked to fail!!"


async def test_get_results_before_tasks_finishes_returns_404(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
# pylint: disable=unused-argument

import asyncio
from typing import AsyncIterable, Final
from collections.abc import AsyncIterable
from typing import Final

import pytest
from asgi_lifespan import LifespanManager
Expand All @@ -24,9 +25,10 @@
get_tasks_manager,
)
from servicelib.fastapi.long_running_tasks.server import setup as setup_server
from servicelib.fastapi.long_running_tasks.server import start_task
from servicelib.fastapi.long_running_tasks.server import (
start_task,
)
from servicelib.long_running_tasks._errors import (
TaskClientResultError,
TaskClientTimeoutError,
)

Expand Down Expand Up @@ -149,16 +151,14 @@ async def test_task_result_task_result_is_an_error(

url = TypeAdapter(AnyHttpUrl).validate_python("http://backgroud.testserver.io/")
client = Client(app=bg_task_app, async_client=async_client, base_url=url)
with pytest.raises(TaskClientResultError) as exec_info:
with pytest.raises(RuntimeError, match="I am failing as requested"):
async with periodic_task_result(
client,
task_id,
task_timeout=10,
status_poll_interval=TASK_SLEEP_INTERVAL / 3,
):
pass
assert f"{exec_info.value}".startswith(f"Task {task_id} finished with exception:")
assert "I am failing as requested" in f"{exec_info.value}"
await _assert_task_removed(async_client, task_id, router_prefix)


Expand Down
Loading
Loading