
Commit b478207

🐛Storage: fix copying of file and ensure project lock release is always notified safely (#7206)

1 parent 547ee94 commit b478207
9 files changed, +87 -94 lines changed

Lines changed: 19 additions & 40 deletions

@@ -1,40 +1,40 @@
 import asyncio
-from collections.abc import AsyncGenerator, Coroutine
-from dataclasses import dataclass
-from typing import Any, Final, TypeAlias
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
 
 from aiohttp import ClientConnectionError, ClientSession
 from tenacity import TryAgain, retry
 from tenacity.asyncio import AsyncRetrying
+from tenacity.before_sleep import before_sleep_log
 from tenacity.retry import retry_if_exception_type
 from tenacity.stop import stop_after_delay
 from tenacity.wait import wait_random_exponential
 from yarl import URL
 
-from ...rest_responses import unwrap_envelope
+from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
+from ...long_running_tasks._models import LRTask, RequestBody
+from ...rest_responses import unwrap_envelope_if_required
 from .. import status
 from .server import TaskGet, TaskId, TaskProgress, TaskStatus
 
-RequestBody: TypeAlias = Any
+_logger = logging.getLogger(__name__)
+
 
-_MINUTE: Final[int] = 60  # in secs
-_HOUR: Final[int] = 60 * _MINUTE  # in secs
-_DEFAULT_POLL_INTERVAL_S: Final[float] = 1
 _DEFAULT_AIOHTTP_RETRY_POLICY: dict[str, Any] = {
     "retry": retry_if_exception_type(ClientConnectionError),
     "wait": wait_random_exponential(max=20),
     "stop": stop_after_delay(60),
     "reraise": True,
+    "before_sleep": before_sleep_log(_logger, logging.INFO),
 }
 
 
 @retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
 async def _start(session: ClientSession, url: URL, json: RequestBody | None) -> TaskGet:
     async with session.post(url, json=json) as response:
         response.raise_for_status()
-        data, error = unwrap_envelope(await response.json())
-        assert not error  # nosec
-        assert data is not None  # nosec
+        data = unwrap_envelope_if_required(await response.json())
         return TaskGet.model_validate(data)
 
 
@@ -50,21 +50,18 @@ async def _wait_for_completion(
             stop=stop_after_delay(client_timeout),
             reraise=True,
             retry=retry_if_exception_type(TryAgain),
+            before_sleep=before_sleep_log(_logger, logging.DEBUG),
         ):
             with attempt:
                 async with session.get(status_url) as response:
                     response.raise_for_status()
-                    data, error = unwrap_envelope(await response.json())
-                    assert not error  # nosec
-                    assert data is not None  # nosec
+                    data = unwrap_envelope_if_required(await response.json())
                     task_status = TaskStatus.model_validate(data)
                     yield task_status.task_progress
                     if not task_status.done:
                         await asyncio.sleep(
                             float(
-                                response.headers.get(
-                                    "retry-after", _DEFAULT_POLL_INTERVAL_S
-                                )
+                                response.headers.get("retry-after", DEFAULT_POLL_INTERVAL_S)
                             )
                         )
                         msg = f"{task_id=}, {task_status.started=} has status: '{task_status.task_progress.message}' {task_status.task_progress.percent}%"
@@ -81,42 +78,21 @@ async def _task_result(session: ClientSession, result_url: URL) -> Any:
     async with session.get(result_url) as response:
         response.raise_for_status()
         if response.status != status.HTTP_204_NO_CONTENT:
-            data, error = unwrap_envelope(await response.json())
-            assert not error  # nosec
-            assert data  # nosec
-            return data
+            return unwrap_envelope_if_required(await response.json())
        return None
 
 
 @retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
 async def _abort_task(session: ClientSession, abort_url: URL) -> None:
     async with session.delete(abort_url) as response:
         response.raise_for_status()
-        data, error = unwrap_envelope(await response.json())
-        assert not error  # nosec
-        assert not data  # nosec
-
-
-@dataclass(frozen=True)
-class LRTask:
-    progress: TaskProgress
-    _result: Coroutine[Any, Any, Any] | None = None
-
-    def done(self) -> bool:
-        return self._result is not None
-
-    async def result(self) -> Any:
-        if not self._result:
-            msg = "No result ready!"
-            raise ValueError(msg)
-        return await self._result
 
 
 async def long_running_task_request(
     session: ClientSession,
     url: URL,
     json: RequestBody | None = None,
-    client_timeout: int = 1 * _HOUR,
+    client_timeout: int = 1 * HOUR,
 ) -> AsyncGenerator[LRTask, None]:
     """Will use the passed `ClientSession` to call an oSparc long
     running task `url` passing `json` as request body.
@@ -147,3 +123,6 @@ async def long_running_task_request(
         if task:
             await _abort_task(session, URL(task.abort_href))
             raise
+
+
+__all__: tuple[str, ...] = ("LRTask",)
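
For orientation, here is a minimal consumer sketch for the refactored aiohttp flavour of `long_running_task_request`. The endpoint URL and payload are invented, and the sketch assumes it runs in (or imports from) the module patched above, whose path is not shown in this diff; only the iteration protocol (`LRTask.progress`, `done()`, `await result()`) comes from the code above.

# Minimal usage sketch -- URL and payload are hypothetical; long_running_task_request
# and LRTask are assumed to be in scope from the patched module above.
import asyncio

from aiohttp import ClientSession
from yarl import URL


async def _run_task() -> None:
    async with ClientSession() as session:
        url = URL("http://example.invalid/v0/some-long-task")  # hypothetical endpoint
        async for lr_task in long_running_task_request(session, url, json={"x": 1}):
            print(f"progress: {lr_task.progress}")  # yielded on every poll
            if lr_task.done():
                print("result:", await lr_task.result())  # final item carries the result


# asyncio.run(_run_task())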

packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@ async def list_tasks(
 
 @router.get(
     "/{task_id}",
+    response_model=TaskStatus,
     responses={
         status.HTTP_404_NOT_FOUND: {"description": "Task does not exist"},
     },
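
The only change here is the added `response_model=TaskStatus`, which makes FastAPI validate and serialize the handler's return value against that model and advertise it in the generated OpenAPI schema. A generic sketch of the mechanism, with an invented route and model (nothing below is taken from this repository):

# Illustrative sketch of FastAPI's response_model behaviour -- names are made up.
from fastapi import FastAPI
from pydantic import BaseModel


class DemoTaskStatus(BaseModel):
    done: bool
    message: str


app = FastAPI()


@app.get("/demo-tasks/{task_id}", response_model=DemoTaskStatus)
async def get_demo_task_status(task_id: str) -> DemoTaskStatus:
    # the return value is validated and serialized against DemoTaskStatus,
    # and the model appears in the generated OpenAPI spec for this route
    return DemoTaskStatus(done=False, message=f"task {task_id} still running")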

packages/service-library/src/servicelib/fastapi/long_running_tasks/client.py

Lines changed: 19 additions & 31 deletions

@@ -3,9 +3,9 @@
 """
 
 import asyncio
-from collections.abc import AsyncGenerator, Coroutine
-from dataclasses import dataclass
-from typing import Any, Final, TypeAlias
+import logging
+from collections.abc import AsyncGenerator
+from typing import Any
 
 import httpx
 from fastapi import status
@@ -14,39 +14,42 @@
 from tenacity import (
     AsyncRetrying,
     TryAgain,
+    before_sleep_log,
     retry,
     retry_if_exception_type,
     stop_after_delay,
     wait_random_exponential,
 )
 from yarl import URL
 
+from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
 from ...long_running_tasks._errors import TaskClientResultError
 from ...long_running_tasks._models import (
     ClientConfiguration,
+    LRTask,
     ProgressCallback,
     ProgressMessage,
     ProgressPercent,
+    RequestBody,
 )
 from ...long_running_tasks._task import TaskId, TaskResult
 from ...rest_responses import unwrap_envelope_if_required
 from ._client import DEFAULT_HTTP_REQUESTS_TIMEOUT, Client, setup
 from ._context_manager import periodic_task_result
 
-RequestBody: TypeAlias = Any
+_logger = logging.getLogger(__name__)
 
-_MINUTE: Final[int] = 60  # in secs
-_HOUR: Final[int] = 60 * _MINUTE  # in secs
-_DEFAULT_POLL_INTERVAL_S: Final[float] = 1
-_DEFAULT_AIOHTTP_RETRY_POLICY: dict[str, Any] = {
+
+_DEFAULT_FASTAPI_RETRY_POLICY: dict[str, Any] = {
     "retry": retry_if_exception_type(httpx.RequestError),
     "wait": wait_random_exponential(max=20),
     "stop": stop_after_delay(60),
     "reraise": True,
+    "before_sleep": before_sleep_log(_logger, logging.INFO),
 }
 
 
-@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
+@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
 async def _start(
     session: httpx.AsyncClient, url: URL, json: RequestBody | None
 ) -> TaskGet:
@@ -56,7 +59,7 @@ async def _start(
     return TaskGet.model_validate(data)
 
 
-@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
+@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
 async def _wait_for_completion(
     session: httpx.AsyncClient,
     task_id: TaskId,
@@ -68,6 +71,7 @@ async def _wait_for_completion(
             stop=stop_after_delay(client_timeout),
             reraise=True,
             retry=retry_if_exception_type(TryAgain),
+            before_sleep=before_sleep_log(_logger, logging.DEBUG),
         ):
             with attempt:
                 response = await session.get(f"{status_url}")
@@ -79,9 +83,7 @@ async def _wait_for_completion(
                 if not task_status.done:
                     await asyncio.sleep(
                         float(
-                            response.headers.get(
-                                "retry-after", _DEFAULT_POLL_INTERVAL_S
-                            )
+                            response.headers.get("retry-after", DEFAULT_POLL_INTERVAL_S)
                         )
                     )
                     msg = f"{task_id=}, {task_status.started=} has status: '{task_status.task_progress.message}' {task_status.task_progress.percent}%"
@@ -93,7 +95,7 @@ async def _wait_for_completion(
        raise TimeoutError(msg) from exc
 
 
-@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
+@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
 async def _task_result(session: httpx.AsyncClient, result_url: URL) -> Any:
     response = await session.get(f"{result_url}", params={"return_exception": True})
     response.raise_for_status()
@@ -102,32 +104,17 @@ async def _task_result(session: httpx.AsyncClient, result_url: URL) -> Any:
     return None
 
 
-@retry(**_DEFAULT_AIOHTTP_RETRY_POLICY)
+@retry(**_DEFAULT_FASTAPI_RETRY_POLICY)
 async def _abort_task(session: httpx.AsyncClient, abort_url: URL) -> None:
     response = await session.delete(f"{abort_url}")
     response.raise_for_status()
 
 
-@dataclass(frozen=True)
-class LRTask:
-    progress: TaskProgress
-    _result: Coroutine[Any, Any, Any] | None = None
-
-    def done(self) -> bool:
-        return self._result is not None
-
-    async def result(self) -> Any:
-        if not self._result:
-            msg = "No result ready!"
-            raise ValueError(msg)
-        return await self._result
-
-
 async def long_running_task_request(
     session: httpx.AsyncClient,
     url: URL,
     json: RequestBody | None = None,
-    client_timeout: int = 1 * _HOUR,
+    client_timeout: int = 1 * HOUR,
 ) -> AsyncGenerator[LRTask, None]:
     """Will use the passed `httpx.AsyncClient` to call an oSparc long
     running task `url` passing `json` as request body.
@@ -164,6 +151,7 @@ async def long_running_task_request(
     "DEFAULT_HTTP_REQUESTS_TIMEOUT",
     "Client",
     "ClientConfiguration",
+    "LRTask",
     "ProgressCallback",
     "ProgressMessage",
     "ProgressPercent",

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+from typing import Final
+
+MINUTE: Final[int] = 60  # in secs
+HOUR: Final[int] = 60 * MINUTE  # in secs
+DEFAULT_POLL_INTERVAL_S: Final[float] = 1
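
These shared constants replace the per-module `_MINUTE`/`_HOUR`/`_DEFAULT_POLL_INTERVAL_S` copies removed from both clients, and both client diffs above also attach tenacity's `before_sleep_log` to their retry policies so every retry is logged before the backoff sleep. A standalone sketch of that logged-retry pattern follows; `_fetch()` and its URL are invented stand-ins for the HTTP calls in the clients.

# Standalone sketch of the logged-retry pattern used by both clients.
# _fetch() and the example URL are invented; the policy mirrors the ones above.
import asyncio
import logging
from typing import Any

import httpx
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_delay,
    wait_random_exponential,
)

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)

_RETRY_POLICY: dict[str, Any] = {
    "retry": retry_if_exception_type(httpx.RequestError),
    "wait": wait_random_exponential(max=20),
    "stop": stop_after_delay(60),
    "reraise": True,
    # logs "Retrying ... in X seconds as it raised ..." before every backoff sleep
    "before_sleep": before_sleep_log(_logger, logging.INFO),
}


@retry(**_RETRY_POLICY)
async def _fetch(client: httpx.AsyncClient, url: str) -> httpx.Response:
    response = await client.get(url)
    response.raise_for_status()
    return response


async def main() -> None:
    async with httpx.AsyncClient() as client:
        # against an unreachable host this retries for up to 60 s, logging each attempt,
        # then reraises the last httpx.RequestError
        await _fetch(client, "http://example.invalid/health")


# asyncio.run(main())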

packages/service-library/src/servicelib/long_running_tasks/_models.py

Lines changed: 21 additions & 3 deletions

@@ -1,6 +1,7 @@
 # mypy: disable-error-code=truthy-function
 from asyncio import Task
 from collections.abc import Awaitable, Callable, Coroutine
+from dataclasses import dataclass
 from datetime import datetime
 from typing import Any, TypeAlias
 
@@ -25,6 +26,8 @@
     [ProgressMessage, ProgressPercent | None, TaskId], Awaitable[None]
 ]
 
+RequestBody: TypeAlias = Any
+
 
 class TrackedTask(BaseModel):
     task_id: str
@@ -56,18 +59,33 @@ class ClientConfiguration(BaseModel):
     default_timeout: PositiveFloat
 
 
+@dataclass(frozen=True)
+class LRTask:
+    progress: TaskProgress
+    _result: Coroutine[Any, Any, Any] | None = None
+
+    def done(self) -> bool:
+        return self._result is not None
+
+    async def result(self) -> Any:
+        if not self._result:
+            msg = "No result ready!"
+            raise ValueError(msg)
+        return await self._result
+
+
 # explicit export of models for api-schemas
 
 assert TaskResult  # nosec
 assert TaskGet  # nosec
 assert TaskStatus  # nosec
 
 __all__: tuple[str, ...] = (
+    "ProgressMessage",
+    "ProgressPercent",
     "TaskGet",
     "TaskId",
+    "TaskProgress",
     "TaskResult",
     "TaskStatus",
-    "TaskProgress",
-    "ProgressPercent",
-    "ProgressMessage",
 )
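
`LRTask` now lives in `_models.py` so both client flavours share a single definition. A small behavioural sketch, assuming `TaskProgress` can be constructed with its field defaults and using an invented coroutine in place of the real `_task_result()` call:

# Behaviour sketch for the relocated LRTask. _fake_result() stands in for the
# real _task_result() coroutine; TaskProgress() with defaults is an assumption.
import asyncio


async def _fake_result() -> str:
    return "42"


async def _demo() -> None:
    pending = LRTask(progress=TaskProgress())
    assert not pending.done()  # no result coroutine attached yet

    finished = LRTask(progress=TaskProgress(), _result=_fake_result())
    assert finished.done()
    print(await finished.result())  # -> 42


# asyncio.run(_demo())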

packages/service-library/src/servicelib/redis/_project_lock.py

Lines changed: 12 additions & 6 deletions

@@ -1,17 +1,20 @@
 import functools
+import logging
 from collections.abc import Awaitable, Callable, Coroutine
 from typing import Any, Final, ParamSpec, TypeVar
 
 from models_library.projects import ProjectID
 from models_library.projects_access import Owner
 from models_library.projects_state import ProjectLocked, ProjectStatus
+from servicelib.logging_utils import log_catch
 
 from ._client import RedisClientSDK
 from ._decorators import exclusive
 from ._errors import CouldNotAcquireLockError, ProjectLockError
 
 _PROJECT_REDIS_LOCK_KEY: Final[str] = "project_lock:{}"
 
+_logger = logging.getLogger(__name__)
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -59,17 +62,20 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
         )
         async def _exclusive_func(*args, **kwargs) -> R:
             if notification_cb is not None:
-                await notification_cb()
+                with log_catch(_logger, reraise=False):
+                    await notification_cb()
             return await func(*args, **kwargs)
 
         try:
-            result = await _exclusive_func(*args, **kwargs)
-            # we are now unlocked
-            if notification_cb is not None:
-                await notification_cb()
-            return result
+            return await _exclusive_func(*args, **kwargs)
+
         except CouldNotAcquireLockError as e:
             raise ProjectLockError from e
+        finally:
+            # we are now unlocked
+            if notification_cb is not None:
+                with log_catch(_logger, reraise=False):
+                    await notification_cb()
 
     return _wrapper
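
The behavioural fix here: the notification callback is wrapped in `log_catch(..., reraise=False)` and the post-release notification moves into a `finally` block, so a failing callback can neither mask the wrapped function's outcome nor skip the "unlocked" notification. A self-contained sketch of that pattern; `log_and_suppress` is an invented stand-in for servicelib's `log_catch`, and both callbacks are made up.

# Self-contained sketch of the notify-in-finally pattern introduced above.
# log_and_suppress is an invented stand-in for servicelib.logging_utils.log_catch.
import asyncio
import contextlib
import logging

_logger = logging.getLogger(__name__)


@contextlib.contextmanager
def log_and_suppress(logger: logging.Logger):
    try:
        yield
    except Exception:  # noqa: BLE001
        logger.exception("suppressed error in notification callback")


async def _flaky_notification_cb() -> None:
    raise RuntimeError("notification transport is down")


async def _locked_section() -> str:
    return "work done while holding the lock"


async def run() -> str:
    with log_and_suppress(_logger):
        await _flaky_notification_cb()  # "locked" notification; failure is only logged
    try:
        return await _locked_section()
    finally:
        # always runs, even if _locked_section() raised: the "unlocked"
        # notification can no longer be skipped by an earlier failure
        with log_and_suppress(_logger):
            await _flaky_notification_cb()


print(asyncio.run(run()))  # -> "work done while holding the lock"

Previously, if the wrapped function raised, the "unlocked" notification never ran at all, and a failing callback could mask an otherwise successful result; the `finally` plus `log_catch` combination removes both failure modes.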
