Skip to content

Commit 477cf6f

Browse files
committed
de-duplicate
1 parent 2de0b94 commit 477cf6f

File tree

5 files changed

+43
-58
lines changed

5 files changed

+43
-58
lines changed

packages/service-library/src/servicelib/aiohttp/long_running_tasks/client.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import asyncio
22
import logging
3-
from collections.abc import AsyncGenerator, Coroutine
4-
from dataclasses import dataclass
5-
from typing import Any, Final, TypeAlias
3+
from collections.abc import AsyncGenerator
4+
from typing import Any
65

76
from aiohttp import ClientConnectionError, ClientSession
87
from tenacity import TryAgain, retry
@@ -13,17 +12,15 @@
1312
from tenacity.wait import wait_random_exponential
1413
from yarl import URL
1514

15+
from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
16+
from ...long_running_tasks._models import LRTask, RequestBody
1617
from ...rest_responses import unwrap_envelope_if_required
1718
from .. import status
1819
from .server import TaskGet, TaskId, TaskProgress, TaskStatus
1920

2021
_logger = logging.getLogger(__name__)
2122

22-
RequestBody: TypeAlias = Any
2323

24-
_MINUTE: Final[int] = 60 # in secs
25-
_HOUR: Final[int] = 60 * _MINUTE # in secs
26-
_DEFAULT_POLL_INTERVAL_S: Final[float] = 1
2724
_DEFAULT_AIOHTTP_RETRY_POLICY: dict[str, Any] = {
2825
"retry": retry_if_exception_type(ClientConnectionError),
2926
"wait": wait_random_exponential(max=20),
@@ -64,9 +61,7 @@ async def _wait_for_completion(
6461
if not task_status.done:
6562
await asyncio.sleep(
6663
float(
67-
response.headers.get(
68-
"retry-after", _DEFAULT_POLL_INTERVAL_S
69-
)
64+
response.headers.get("retry-after", DEFAULT_POLL_INTERVAL_S)
7065
)
7166
)
7267
msg = f"{task_id=}, {task_status.started=} has status: '{task_status.task_progress.message}' {task_status.task_progress.percent}%"
@@ -93,26 +88,11 @@ async def _abort_task(session: ClientSession, abort_url: URL) -> None:
9388
response.raise_for_status()
9489

9590

96-
@dataclass(frozen=True)
97-
class LRTask:
98-
progress: TaskProgress
99-
_result: Coroutine[Any, Any, Any] | None = None
100-
101-
def done(self) -> bool:
102-
return self._result is not None
103-
104-
async def result(self) -> Any:
105-
if not self._result:
106-
msg = "No result ready!"
107-
raise ValueError(msg)
108-
return await self._result
109-
110-
11191
async def long_running_task_request(
11292
session: ClientSession,
11393
url: URL,
11494
json: RequestBody | None = None,
115-
client_timeout: int = 1 * _HOUR,
95+
client_timeout: int = 1 * HOUR,
11696
) -> AsyncGenerator[LRTask, None]:
11797
"""Will use the passed `ClientSession` to call an oSparc long
11898
running task `url` passing `json` as request body.

packages/service-library/src/servicelib/fastapi/long_running_tasks/client.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import logging
77
from collections.abc import AsyncGenerator, Coroutine
88
from dataclasses import dataclass
9-
from typing import Any, Final, TypeAlias
9+
from typing import Any
1010

1111
import httpx
1212
from fastapi import status
@@ -23,12 +23,15 @@
2323
)
2424
from yarl import URL
2525

26+
from ...long_running_tasks._constants import DEFAULT_POLL_INTERVAL_S, HOUR
2627
from ...long_running_tasks._errors import TaskClientResultError
2728
from ...long_running_tasks._models import (
2829
ClientConfiguration,
30+
LRTask,
2931
ProgressCallback,
3032
ProgressMessage,
3133
ProgressPercent,
34+
RequestBody,
3235
)
3336
from ...long_running_tasks._task import TaskId, TaskResult
3437
from ...rest_responses import unwrap_envelope_if_required
@@ -37,11 +40,7 @@
3740

3841
_logger = logging.getLogger(__name__)
3942

40-
RequestBody: TypeAlias = Any
4143

42-
_MINUTE: Final[int] = 60 # in secs
43-
_HOUR: Final[int] = 60 * _MINUTE # in secs
44-
_DEFAULT_POLL_INTERVAL_S: Final[float] = 1
4544
_DEFAULT_FASTAPI_RETRY_POLICY: dict[str, Any] = {
4645
"retry": retry_if_exception_type(httpx.RequestError),
4746
"wait": wait_random_exponential(max=20),
@@ -85,9 +84,7 @@ async def _wait_for_completion(
8584
if not task_status.done:
8685
await asyncio.sleep(
8786
float(
88-
response.headers.get(
89-
"retry-after", _DEFAULT_POLL_INTERVAL_S
90-
)
87+
response.headers.get("retry-after", DEFAULT_POLL_INTERVAL_S)
9188
)
9289
)
9390
msg = f"{task_id=}, {task_status.started=} has status: '{task_status.task_progress.message}' {task_status.task_progress.percent}%"
@@ -114,26 +111,11 @@ async def _abort_task(session: httpx.AsyncClient, abort_url: URL) -> None:
114111
response.raise_for_status()
115112

116113

117-
@dataclass(frozen=True)
118-
class LRTask:
119-
progress: TaskProgress
120-
_result: Coroutine[Any, Any, Any] | None = None
121-
122-
def done(self) -> bool:
123-
return self._result is not None
124-
125-
async def result(self) -> Any:
126-
if not self._result:
127-
msg = "No result ready!"
128-
raise ValueError(msg)
129-
return await self._result
130-
131-
132114
async def long_running_task_request(
133115
session: httpx.AsyncClient,
134116
url: URL,
135117
json: RequestBody | None = None,
136-
client_timeout: int = 1 * _HOUR,
118+
client_timeout: int = 1 * HOUR,
137119
) -> AsyncGenerator[LRTask, None]:
138120
"""Will use the passed `httpx.AsyncClient` to call an oSparc long
139121
running task `url` passing `json` as request body.
@@ -167,16 +149,17 @@ async def long_running_task_request(
167149

168150

169151
__all__: tuple[str, ...] = (
170-
"DEFAULT_HTTP_REQUESTS_TIMEOUT",
171152
"Client",
172153
"ClientConfiguration",
154+
"DEFAULT_HTTP_REQUESTS_TIMEOUT",
155+
"LRTask",
156+
"periodic_task_result",
173157
"ProgressCallback",
174158
"ProgressMessage",
175159
"ProgressPercent",
160+
"setup",
176161
"TaskClientResultError",
177162
"TaskId",
178163
"TaskResult",
179-
"periodic_task_result",
180-
"setup",
181164
)
182165
# nopycln: file
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from typing import Final
2+
3+
MINUTE: Final[int] = 60 # in secs
4+
HOUR: Final[int] = 60 * MINUTE # in secs
5+
DEFAULT_POLL_INTERVAL_S: Final[float] = 1

packages/service-library/src/servicelib/long_running_tasks/_models.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# mypy: disable-error-code=truthy-function
22
from asyncio import Task
33
from collections.abc import Awaitable, Callable, Coroutine
4+
from dataclasses import dataclass
45
from datetime import datetime
56
from typing import Any, TypeAlias
67

@@ -25,6 +26,8 @@
2526
[ProgressMessage, ProgressPercent | None, TaskId], Awaitable[None]
2627
]
2728

29+
RequestBody: TypeAlias = Any
30+
2831

2932
class TrackedTask(BaseModel):
3033
task_id: str
@@ -56,18 +59,33 @@ class ClientConfiguration(BaseModel):
5659
default_timeout: PositiveFloat
5760

5861

62+
@dataclass(frozen=True)
63+
class LRTask:
64+
progress: TaskProgress
65+
_result: Coroutine[Any, Any, Any] | None = None
66+
67+
def done(self) -> bool:
68+
return self._result is not None
69+
70+
async def result(self) -> Any:
71+
if not self._result:
72+
msg = "No result ready!"
73+
raise ValueError(msg)
74+
return await self._result
75+
76+
5977
# explicit export of models for api-schemas
6078

6179
assert TaskResult # nosec
6280
assert TaskGet # nosec
6381
assert TaskStatus # nosec
6482

6583
__all__: tuple[str, ...] = (
84+
"ProgressMessage",
85+
"ProgressPercent",
6686
"TaskGet",
6787
"TaskId",
88+
"TaskProgress",
6889
"TaskResult",
6990
"TaskStatus",
70-
"TaskProgress",
71-
"ProgressPercent",
72-
"ProgressMessage",
7391
)

packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# pylint: disable=unused-argument
33

44
import asyncio
5-
from typing import Callable
5+
from collections.abc import Callable
66

77
import pytest
88
from aiohttp import ClientResponseError, web
@@ -36,7 +36,6 @@ def client(
3636
unused_tcp_port_factory: Callable,
3737
app: web.Application,
3838
) -> TestClient:
39-
4039
return event_loop.run_until_complete(
4140
aiohttp_client(app, server_kwargs={"port": unused_tcp_port_factory()})
4241
)

0 commit comments

Comments
 (0)