Skip to content
Original file line number Diff line number Diff line change
@@ -1,40 +1,40 @@
import asyncio
from collections.abc import AsyncGenerator, Coroutine
from dataclasses import dataclass
from typing import Any, Final, TypeAlias
import logging
from collections.abc import AsyncGenerator
from typing import Any

from aiohttp import ClientConnectionError, ClientSession
from tenacity import TryAgain, retry
from tenacity.asyncio import AsyncRetrying
from tenacity.before_sleep import before_sleep_log
from tenacity.retry import retry_if_exception_type
from tenacity.stop import stop_after_delay
from tenacity.wait import wait_random_exponential
from yarl import URL

from ...rest_responses import unwrap_envelope
from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
from ...long_running_tasks._models import LRTask, RequestBody
from ...rest_responses import unwrap_envelope_if_required
from .. import status
from .server import TaskGet, TaskId, TaskProgress, TaskStatus

RequestBody: TypeAlias = Any
_logger = logging.getLogger(__name__)


_MINUTE: Final[int] = 60 # in secs
_HOUR: Final[int] = 60 * _MINUTE # in secs
_DEFAULT_POLL_INTERVAL_S: Final[float] = 1
_DEFAULT_AIOHTTP_RETRY_POLICY: dict[str, Any] = {
"retry": retry_if_exception_type(ClientConnectionError),
"wait": wait_random_exponential(max=20),
"stop": stop_after_delay(60),
"reraise": True,
"before_sleep": before_sleep_log(_logger, logging.INFO),
}


@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
async def _start(session: ClientSession, url: URL, json: RequestBody | None) -> TaskGet:
async with session.post(url, json=json) as response:
response.raise_for_status()
data, error = unwrap_envelope(await response.json())
assert not error # nosec
assert data is not None # nosec
data = unwrap_envelope_if_required(await response.json())
return TaskGet.model_validate(data)


Expand All @@ -50,21 +50,18 @@ async def _wait_for_completion(
stop=stop_after_delay(client_timeout),
reraise=True,
retry=retry_if_exception_type(TryAgain),
before_sleep=before_sleep_log(_logger, logging.DEBUG),
):
with attempt:
async with session.get(status_url) as response:
response.raise_for_status()
data, error = unwrap_envelope(await response.json())
assert not error # nosec
assert data is not None # nosec
data = unwrap_envelope_if_required(await response.json())
task_status = TaskStatus.model_validate(data)
yield task_status.task_progress
if not task_status.done:
await asyncio.sleep(
float(
response.headers.get(
"retry-after", _DEFAULT_POLL_INTERVAL_S
)
response.headers.get("retry-after", DEFAULT_POLL_INTERVAL_S)
)
)
msg = f"{task_id=}, {task_status.started=} has status: '{task_status.task_progress.message}' {task_status.task_progress.percent}%"
Expand All @@ -81,42 +78,21 @@ async def _task_result(session: ClientSession, result_url: URL) -> Any:
async with session.get(result_url) as response:
response.raise_for_status()
if response.status != status.HTTP_204_NO_CONTENT:
data, error = unwrap_envelope(await response.json())
assert not error # nosec
assert data # nosec
return data
return unwrap_envelope_if_required(await response.json())
return None


@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
async def _abort_task(session: ClientSession, abort_url: URL) -> None:
    """Request cancellation of a long-running task via its abort endpoint.

    Retries on transient connection errors per _DEFAULT_AIOHTTP_RETRY_POLICY.
    Raises aiohttp.ClientResponseError on a non-2xx response.
    """
    async with session.delete(abort_url) as response:
        response.raise_for_status()
        data, error = unwrap_envelope(await response.json())
        # a successful abort returns an empty envelope: neither error nor data
        assert not error  # nosec
        assert not data  # nosec


@dataclass(frozen=True)
class LRTask:
    """Snapshot of a long-running task: its current progress plus, once the
    task has finished, an awaitable producing its result."""

    progress: TaskProgress
    # coroutine yielding the task's result; None while the task is still running
    _result: Coroutine[Any, Any, Any] | None = None

    def done(self) -> bool:
        """Return True once a result coroutine is available to await."""
        return self._result is not None

    async def result(self) -> Any:
        """Await and return the task's result.

        Raises:
            ValueError: if called before the task completed (no result ready).
        """
        if not self._result:
            msg = "No result ready!"
            raise ValueError(msg)
        return await self._result


async def long_running_task_request(
session: ClientSession,
url: URL,
json: RequestBody | None = None,
client_timeout: int = 1 * _HOUR,
client_timeout: int = 1 * HOUR,
) -> AsyncGenerator[LRTask, None]:
"""Will use the passed `ClientSession` to call an oSparc long
running task `url` passing `json` as request body.
Expand Down Expand Up @@ -147,3 +123,6 @@ async def long_running_task_request(
if task:
await _abort_task(session, URL(task.abort_href))
raise


__all__: tuple[str, ...] = ("LRTask",)
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ async def list_tasks(

@router.get(
"/{task_id}",
response_model=TaskStatus,
responses={
status.HTTP_404_NOT_FOUND: {"description": "Task does not exist"},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
"""

import asyncio
from collections.abc import AsyncGenerator, Coroutine
from dataclasses import dataclass
from typing import Any, Final, TypeAlias
import logging
from collections.abc import AsyncGenerator
from typing import Any

import httpx
from fastapi import status
Expand All @@ -14,39 +14,42 @@
from tenacity import (
AsyncRetrying,
TryAgain,
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_delay,
wait_random_exponential,
)
from yarl import URL

from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
from ...long_running_tasks._errors import TaskClientResultError
from ...long_running_tasks._models import (
ClientConfiguration,
LRTask,
ProgressCallback,
ProgressMessage,
ProgressPercent,
RequestBody,
)
from ...long_running_tasks._task import TaskId, TaskResult
from ...rest_responses import unwrap_envelope_if_required
from ._client import DEFAULT_HTTP_REQUESTS_TIMEOUT, Client, setup
from ._context_manager import periodic_task_result

RequestBody: TypeAlias = Any
_logger = logging.getLogger(__name__)

_MINUTE: Final[int] = 60 # in secs
_HOUR: Final[int] = 60 * _MINUTE # in secs
_DEFAULT_POLL_INTERVAL_S: Final[float] = 1
_DEFAULT_AIOHTTP_RETRY_POLICY: dict[str, Any] = {

_DEFAULT_FASTAPI_RETRY_POLICY: dict[str, Any] = {
"retry": retry_if_exception_type(httpx.RequestError),
"wait": wait_random_exponential(max=20),
"stop": stop_after_delay(60),
"reraise": True,
"before_sleep": before_sleep_log(_logger, logging.INFO),
}


@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
async def _start(
session: httpx.AsyncClient, url: URL, json: RequestBody | None
) -> TaskGet:
Expand All @@ -56,7 +59,7 @@ async def _start(
return TaskGet.model_validate(data)


@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
async def _wait_for_completion(
session: httpx.AsyncClient,
task_id: TaskId,
Expand All @@ -68,6 +71,7 @@ async def _wait_for_completion(
stop=stop_after_delay(client_timeout),
reraise=True,
retry=retry_if_exception_type(TryAgain),
before_sleep=before_sleep_log(_logger, logging.DEBUG),
):
with attempt:
response = await session.get(f"{status_url}")
Expand All @@ -79,9 +83,7 @@ async def _wait_for_completion(
if not task_status.done:
await asyncio.sleep(
float(
response.headers.get(
"retry-after", _DEFAULT_POLL_INTERVAL_S
)
response.headers.get("retry-after", DEFAULT_POLL_INTERVAL_S)
)
)
msg = f"{task_id=}, {task_status.started=} has status: '{task_status.task_progress.message}' {task_status.task_progress.percent}%"
Expand All @@ -93,7 +95,7 @@ async def _wait_for_completion(
raise TimeoutError(msg) from exc


@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
async def _task_result(session: httpx.AsyncClient, result_url: URL) -> Any:
response = await session.get(f"{result_url}", params={"return_exception": True})
response.raise_for_status()
Expand All @@ -102,32 +104,17 @@ async def _task_result(session: httpx.AsyncClient, result_url: URL) -> Any:
return None


@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
async def _abort_task(session: httpx.AsyncClient, abort_url: URL) -> None:
    """Request cancellation of a long-running task via its abort endpoint.

    Retries on transient httpx.RequestError per _DEFAULT_FASTAPI_RETRY_POLICY.
    Raises httpx.HTTPStatusError on a non-2xx response.
    """
    response = await session.delete(f"{abort_url}")
    response.raise_for_status()


@dataclass(frozen=True)
class LRTask:
    """Snapshot of a long-running task: its current progress plus, once the
    task has finished, an awaitable producing its result."""

    progress: TaskProgress
    # coroutine yielding the task's result; None while the task is still running
    _result: Coroutine[Any, Any, Any] | None = None

    def done(self) -> bool:
        """Return True once a result coroutine is available to await."""
        return self._result is not None

    async def result(self) -> Any:
        """Await and return the task's result.

        Raises:
            ValueError: if called before the task completed (no result ready).
        """
        if not self._result:
            msg = "No result ready!"
            raise ValueError(msg)
        return await self._result


async def long_running_task_request(
session: httpx.AsyncClient,
url: URL,
json: RequestBody | None = None,
client_timeout: int = 1 * _HOUR,
client_timeout: int = 1 * HOUR,
) -> AsyncGenerator[LRTask, None]:
"""Will use the passed `httpx.AsyncClient` to call an oSparc long
running task `url` passing `json` as request body.
Expand Down Expand Up @@ -164,6 +151,7 @@ async def long_running_task_request(
"DEFAULT_HTTP_REQUESTS_TIMEOUT",
"Client",
"ClientConfiguration",
"LRTask",
"ProgressCallback",
"ProgressMessage",
"ProgressPercent",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Shared time and polling constants for the long_running_tasks helpers."""

from typing import Final

MINUTE: Final[int] = 60  # in secs
HOUR: Final[int] = 60 * MINUTE  # in secs
DEFAULT_POLL_INTERVAL_S: Final[float] = 1  # fallback when no retry-after header is sent
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# mypy: disable-error-code=truthy-function
from asyncio import Task
from collections.abc import Awaitable, Callable, Coroutine
from dataclasses import dataclass
from datetime import datetime
from typing import Any, TypeAlias

Expand All @@ -25,6 +26,8 @@
[ProgressMessage, ProgressPercent | None, TaskId], Awaitable[None]
]

RequestBody: TypeAlias = Any


class TrackedTask(BaseModel):
task_id: str
Expand Down Expand Up @@ -56,18 +59,33 @@ class ClientConfiguration(BaseModel):
default_timeout: PositiveFloat


@dataclass(frozen=True)
class LRTask:
    """Progress snapshot of a long-running task.

    While the task runs only `progress` is meaningful; once it finishes,
    `_result` holds the coroutine that produces the final value.
    """

    progress: TaskProgress
    _result: Coroutine[Any, Any, Any] | None = None

    def done(self) -> bool:
        """True once the task's result is available to await."""
        return self._result is not None

    async def result(self) -> Any:
        """Await the finished task's result.

        Raises:
            ValueError: when no result is ready yet.
        """
        if self._result is None:
            raise ValueError("No result ready!")
        return await self._result


# explicit export of models for api-schemas

assert TaskResult # nosec
assert TaskGet # nosec
assert TaskStatus # nosec

__all__: tuple[str, ...] = (
"ProgressMessage",
"ProgressPercent",
"TaskGet",
"TaskId",
"TaskProgress",
"TaskResult",
"TaskStatus",
"TaskProgress",
"ProgressPercent",
"ProgressMessage",
)
18 changes: 12 additions & 6 deletions packages/service-library/src/servicelib/redis/_project_lock.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import functools
import logging
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, Final, ParamSpec, TypeVar

from models_library.projects import ProjectID
from models_library.projects_access import Owner
from models_library.projects_state import ProjectLocked, ProjectStatus
from servicelib.logging_utils import log_catch

from ._client import RedisClientSDK
from ._decorators import exclusive
from ._errors import CouldNotAcquireLockError, ProjectLockError

_PROJECT_REDIS_LOCK_KEY: Final[str] = "project_lock:{}"

_logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -59,17 +62,20 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
)
async def _exclusive_func(*args, **kwargs) -> R:
if notification_cb is not None:
await notification_cb()
with log_catch(_logger, reraise=False):
await notification_cb()
return await func(*args, **kwargs)

try:
result = await _exclusive_func(*args, **kwargs)
# we are now unlocked
if notification_cb is not None:
await notification_cb()
return result
return await _exclusive_func(*args, **kwargs)

except CouldNotAcquireLockError as e:
raise ProjectLockError from e
finally:
# we are now unlocked
if notification_cb is not None:
with log_catch(_logger, reraise=False):
await notification_cb()

return _wrapper

Expand Down
Loading
Loading