Skip to content

Commit 01c58db

Browse files
author
Andrei Neagu
committed
refactor to use registered task
1 parent 77855dd commit 01c58db

File tree

15 files changed

+210
-126
lines changed

15 files changed

+210
-126
lines changed

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
DEFAULT_STALE_TASK_DETECT_TIMEOUT,
1515
)
1616
from ...long_running_tasks.models import TaskGet
17-
from ...long_running_tasks.task import TaskContext, TaskProtocol
17+
from ...long_running_tasks.task import RegisteredTaskName, TaskContext
1818
from ..typing_extension import Handler
1919
from . import _routes
2020
from ._constants import (
@@ -45,7 +45,7 @@ def _create_task_name_from_request(request: web.Request) -> str:
4545
async def start_long_running_task(
4646
# NOTE: positional argument are suffixed with "_" to avoid name conflicts with "task_kwargs" keys
4747
request_: web.Request,
48-
task_: TaskProtocol,
48+
registerd_task_name: RegisteredTaskName,
4949
*,
5050
fire_and_forget: bool = False,
5151
task_context: TaskContext,
@@ -56,7 +56,7 @@ async def start_long_running_task(
5656
task_id = None
5757
try:
5858
task_id = long_running_manager.tasks_manager.start_task(
59-
task_,
59+
registerd_task_name,
6060
fire_and_forget=fire_and_forget,
6161
task_context=task_context,
6262
task_name=task_name,

packages/service-library/src/servicelib/long_running_tasks/errors.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@ class BaseLongRunningError(OsparcErrorMixin, Exception):
55
"""base exception for this module"""
66

77

8+
class TaskNotRegisteredError(BaseLongRunningError):
9+
msg_template: str = (
10+
"notask with task_name='{task_name}' was found in the task registry. "
11+
"Make sure it's registered before starting it."
12+
)
13+
14+
815
class TaskAlreadyRunningError(BaseLongRunningError):
916
msg_template: str = "{task_name} must be unique, found: '{managed_task}'"
1017

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import urllib.parse
77
from collections import deque
88
from contextlib import suppress
9-
from typing import Any, Final, Protocol, TypeAlias
9+
from typing import Any, ClassVar, Final, Protocol, TypeAlias
1010
from uuid import uuid4
1111

1212
from models_library.api_schemas_long_running_tasks.base import TaskProgress
@@ -21,6 +21,7 @@
2121
TaskExceptionError,
2222
TaskNotCompletedError,
2323
TaskNotFoundError,
24+
TaskNotRegisteredError,
2425
)
2526
from .models import TaskId, TaskStatus, TrackedTask
2627

@@ -35,6 +36,7 @@
3536
seconds=1
3637
).total_seconds()
3738

39+
RegisteredTaskName: TypeAlias = str
3840
Namespace: TypeAlias = str
3941
TrackedTaskGroupDict: TypeAlias = dict[TaskId, TrackedTask]
4042
TaskContext: TypeAlias = dict[str, Any]
@@ -49,6 +51,19 @@ async def __call__(
4951
def __name__(self) -> str: ...
5052

5153

54+
class TaskRegistry:
55+
REGISTERED_TASKS: ClassVar[dict[RegisteredTaskName, TaskProtocol]] = {}
56+
57+
@classmethod
58+
def register(cls, task: TaskProtocol) -> None:
59+
cls.REGISTERED_TASKS[task.__name__] = task
60+
61+
@classmethod
62+
def unregister(cls, task: TaskProtocol) -> None:
63+
if task.__name__ in cls.REGISTERED_TASKS:
64+
del cls.REGISTERED_TASKS[task.__name__]
65+
66+
5267
async def _await_task(task: asyncio.Task) -> None:
5368
await task
5469

@@ -318,7 +333,7 @@ def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
318333

319334
def start_task(
320335
self,
321-
task: TaskProtocol,
336+
registered_task_name: RegisteredTaskName,
322337
*,
323338
unique: bool = False,
324339
task_context: TaskContext | None = None,
@@ -351,6 +366,10 @@ def start_task(
351366
Returns:
352367
TaskId: the task unique identifier
353368
"""
369+
if registered_task_name not in TaskRegistry.REGISTERED_TASKS:
370+
raise TaskNotRegisteredError(task_name=registered_task_name)
371+
372+
task = TaskRegistry.REGISTERED_TASKS[registered_task_name]
354373

355374
# NOTE: If not task name is given, it will be composed of the handler's module and it's name
356375
# to keep the urls shorter and more meaningful.

packages/service-library/tests/aiohttp/long_running_tasks/conftest.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,15 @@
1919
TaskProgress,
2020
TaskStatus,
2121
)
22-
from servicelib.long_running_tasks.task import TaskContext
22+
from servicelib.long_running_tasks.task import TaskContext, TaskRegistry
2323
from tenacity.asyncio import AsyncRetrying
2424
from tenacity.retry import retry_if_exception_type
2525
from tenacity.stop import stop_after_delay
2626
from tenacity.wait import wait_fixed
2727

2828

2929
async def _string_list_task(
30-
task_progress: TaskProgress,
30+
progress: TaskProgress,
3131
num_strings: int,
3232
sleep_time: float,
3333
fail: bool,
@@ -36,7 +36,7 @@ async def _string_list_task(
3636
for index in range(num_strings):
3737
generated_strings.append(f"{index}")
3838
await asyncio.sleep(sleep_time)
39-
task_progress.update(message="generated item", percent=index / num_strings)
39+
progress.update(message="generated item", percent=index / num_strings)
4040
if fail:
4141
msg = "We were asked to fail!!"
4242
raise RuntimeError(msg)
@@ -47,6 +47,9 @@ async def _string_list_task(
4747
)
4848

4949

50+
TaskRegistry.register(_string_list_task)
51+
52+
5053
@pytest.fixture
5154
def task_context(faker: Faker) -> TaskContext:
5255
return {"user_id": faker.pyint(), "product": faker.pystr()}
@@ -73,7 +76,7 @@ async def generate_list_strings(request: web.Request) -> web.Response:
7376
query_params = parse_request_query_parameters_as(_LongTaskQueryParams, request)
7477
return await long_running_tasks.server.start_long_running_task(
7578
request,
76-
_string_list_task,
79+
_string_list_task.__name__,
7780
num_strings=query_params.num_strings,
7881
sleep_time=query_params.sleep_time,
7982
fail=query_params.fail,

packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ async def test_workflow(
9797
# now get the result
9898
result_url = client.app.router["get_task_result"].url_for(task_id=task_id)
9999
result = await client.get(f"{result_url}")
100-
task_result, error = await assert_status(result, status.HTTP_201_CREATED)
100+
task_result, error = await assert_status(result, status.HTTP_200_OK)
101101
assert task_result
102102
assert not error
103103
assert task_result == [f"{x}" for x in range(10)]

packages/service-library/tests/aiohttp/long_running_tasks/test_long_running_tasks_with_task_context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ async def test_get_task_result(
151151
await assert_status(resp, status.HTTP_404_NOT_FOUND)
152152
# calling with context should find the task
153153
resp = await client_with_task_context.get(f"{result_url.with_query(task_context)}")
154-
await assert_status(resp, status.HTTP_201_CREATED)
154+
await assert_status(resp, status.HTTP_200_OK)
155155

156156

157157
async def test_cancel_task(

packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import asyncio
1313
import json
1414
from collections.abc import AsyncIterator, Awaitable, Callable
15-
from typing import Final
15+
from typing import Annotated, Final
1616

1717
import pytest
1818
from asgi_lifespan import LifespanManager
@@ -31,7 +31,7 @@
3131
TaskProgress,
3232
TaskStatus,
3333
)
34-
from servicelib.long_running_tasks.task import TaskContext
34+
from servicelib.long_running_tasks.task import TaskContext, TaskRegistry
3535
from tenacity.asyncio import AsyncRetrying
3636
from tenacity.retry import retry_if_exception_type
3737
from tenacity.stop import stop_after_delay
@@ -42,7 +42,7 @@
4242

4343

4444
async def _string_list_task(
45-
task_progress: TaskProgress,
45+
progress: TaskProgress,
4646
num_strings: int,
4747
sleep_time: float,
4848
fail: bool,
@@ -51,14 +51,17 @@ async def _string_list_task(
5151
for index in range(num_strings):
5252
generated_strings.append(f"{index}")
5353
await asyncio.sleep(sleep_time)
54-
task_progress.update(message="generated item", percent=index / num_strings)
54+
progress.update(message="generated item", percent=index / num_strings)
5555
if fail:
5656
msg = "We were asked to fail!!"
5757
raise RuntimeError(msg)
5858

5959
return generated_strings
6060

6161

62+
TaskRegistry.register(_string_list_task)
63+
64+
6265
@pytest.fixture
6366
def server_routes() -> APIRouter:
6467
routes = APIRouter()
@@ -69,18 +72,18 @@ def server_routes() -> APIRouter:
6972
async def create_string_list_task(
7073
num_strings: int,
7174
sleep_time: float,
75+
long_running_manager: Annotated[
76+
FastAPILongRunningManager, Depends(get_long_running_manager)
77+
],
78+
*,
7279
fail: bool = False,
73-
long_running_manager: FastAPILongRunningManager = Depends(
74-
get_long_running_manager
75-
),
7680
) -> TaskId:
77-
task_id = long_running_manager.tasks_manager.start_task(
78-
_string_list_task,
81+
return long_running_manager.tasks_manager.start_task(
82+
_string_list_task.__name__,
7983
num_strings=num_strings,
8084
sleep_time=sleep_time,
8185
fail=fail,
8286
)
83-
return task_id
8487

8588
return routes
8689

packages/service-library/tests/fastapi/long_running_tasks/test_long_running_tasks_context_manager.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TaskId,
2626
TaskProgress,
2727
)
28+
from servicelib.long_running_tasks.task import TaskRegistry
2829

2930
TASK_SLEEP_INTERVAL: Final[PositiveFloat] = 0.1
3031

@@ -38,17 +39,25 @@ async def _assert_task_removed(
3839
assert result.status_code == status.HTTP_404_NOT_FOUND
3940

4041

41-
async def a_test_task(task_progress: TaskProgress) -> int:
42+
async def a_test_task(progress: TaskProgress) -> int:
43+
_ = progress
4244
await asyncio.sleep(TASK_SLEEP_INTERVAL)
4345
return 42
4446

4547

46-
async def a_failing_test_task(task_progress: TaskProgress) -> None:
48+
TaskRegistry.register(a_test_task)
49+
50+
51+
async def a_failing_test_task(progress: TaskProgress) -> None:
52+
_ = progress
4753
await asyncio.sleep(TASK_SLEEP_INTERVAL)
4854
msg = "I am failing as requested"
4955
raise RuntimeError(msg)
5056

5157

58+
TaskRegistry.register(a_failing_test_task)
59+
60+
5261
@pytest.fixture
5362
def user_routes() -> APIRouter:
5463
router = APIRouter()
@@ -59,15 +68,17 @@ async def create_task_user_defined_route(
5968
FastAPILongRunningManager, Depends(get_long_running_manager)
6069
],
6170
) -> TaskId:
62-
return long_running_manager.tasks_manager.start_task(task=a_test_task)
71+
return long_running_manager.tasks_manager.start_task(a_test_task.__name__)
6372

6473
@router.get("/api/failing", status_code=status.HTTP_200_OK)
6574
async def create_task_which_fails(
6675
long_running_manager: Annotated[
6776
FastAPILongRunningManager, Depends(get_long_running_manager)
6877
],
6978
) -> TaskId:
70-
return long_running_manager.tasks_manager.start_task(task=a_failing_test_task)
79+
return long_running_manager.tasks_manager.start_task(
80+
a_failing_test_task.__name__
81+
)
7182

7283
return router
7384

0 commit comments

Comments
 (0)