Skip to content

Commit fb8628a

Browse files
author
Andrei Neagu
committed
added in memory store to TasksManager
1 parent f1c60a8 commit fb8628a

File tree

9 files changed

+133
-71
lines changed

9 files changed

+133
-71
lines changed

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ async def list_tasks(request: web.Request) -> web.Response:
2828
result_href=f"{request.app.router['get_task_result'].url_for(task_id=t.task_id)}",
2929
abort_href=f"{request.app.router['cancel_and_delete_task'].url_for(task_id=t.task_id)}",
3030
)
31-
for t in lrt_api.list_tasks(
31+
for t in await lrt_api.list_tasks(
3232
long_running_manager.tasks_manager,
3333
long_running_manager.get_task_context(request),
3434
)
@@ -41,7 +41,7 @@ async def get_task_status(request: web.Request) -> web.Response:
4141
path_params = parse_request_path_parameters_as(_PathParam, request)
4242
long_running_manager = get_long_running_manager(request.app)
4343

44-
task_status: TaskStatus = lrt_api.get_task_status(
44+
task_status: TaskStatus = await lrt_api.get_task_status(
4545
long_running_manager.tasks_manager,
4646
long_running_manager.get_task_context(request),
4747
path_params.task_id,

packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async def list_tasks(
2929
request.url_for("cancel_and_delete_task", task_id=t.task_id)
3030
),
3131
)
32-
for t in lrt_api.list_tasks(
32+
for t in await lrt_api.list_tasks(
3333
long_running_manager.tasks_manager, task_context=None
3434
)
3535
]
@@ -51,7 +51,7 @@ async def get_task_status(
5151
],
5252
) -> TaskStatus:
5353
assert request # nosec
54-
return lrt_api.get_task_status(
54+
return await lrt_api.get_task_status(
5555
long_running_manager.tasks_manager, task_context=None, task_id=task_id
5656
)
5757

packages/service-library/src/servicelib/long_running_tasks/_store/__init__.py

Whitespace-only changes.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from abc import abstractmethod
2+
3+
from ..models import TaskId, TrackedTask
4+
5+
6+
class BaseStore:
7+
8+
@abstractmethod
9+
async def get(self, key: TaskId) -> TrackedTask | None:
10+
"""Retrieve a tracked task by its key."""
11+
12+
@abstractmethod
13+
async def set(self, key: TaskId, value: TrackedTask) -> None:
14+
"""Set a tracked task with its key."""
15+
16+
@abstractmethod
17+
async def list(self) -> list[TrackedTask]:
18+
"""List all tracked tasks."""
19+
20+
@abstractmethod
21+
async def delete(self, key: TaskId) -> None:
22+
"""Delete a tracked task by its key."""
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from ..models import TaskId, TrackedTask
2+
from .base import BaseStore
3+
4+
5+
class InMemoryStore(BaseStore):
6+
def __init__(self):
7+
self._store: dict[TaskId, TrackedTask] = {}
8+
9+
async def get(self, key: TaskId) -> TrackedTask | None:
10+
return self._store.get(key, None)
11+
12+
async def set(self, key: TaskId, value: TrackedTask) -> None:
13+
self._store[key] = value
14+
15+
async def list(self) -> list[TrackedTask]:
16+
return list(self._store.values())
17+
18+
async def delete(self, key: TaskId) -> None:
19+
self._store.pop(key, None)

packages/service-library/src/servicelib/long_running_tasks/lrt_api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ async def start_task(
4646
Returns:
4747
TaskId: the task unique identifier
4848
"""
49-
return tasks_manager.start_task(
49+
return await tasks_manager.start_task(
5050
registered_task_name,
5151
unique=unique,
5252
task_context=task_context,
@@ -56,17 +56,17 @@ async def start_task(
5656
)
5757

5858

59-
def list_tasks(
59+
async def list_tasks(
6060
tasks_manager: TasksManager, task_context: TaskContext | None
6161
) -> list[TaskBase]:
62-
return tasks_manager.list_tasks(with_task_context=task_context)
62+
return await tasks_manager.list_tasks(with_task_context=task_context)
6363

6464

65-
def get_task_status(
65+
async def get_task_status(
6666
tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId
6767
) -> TaskStatus:
6868
"""returns the status of a task"""
69-
return tasks_manager.get_task_status(
69+
return await tasks_manager.get_task_status(
7070
task_id=task_id, with_task_context=task_context
7171
)
7272

@@ -75,7 +75,7 @@ async def get_task_result(
7575
tasks_manager: TasksManager, task_context: TaskContext | None, task_id: TaskId
7676
) -> Any:
7777
try:
78-
task_result = tasks_manager.get_task_result(
78+
task_result = await tasks_manager.get_task_result(
7979
task_id, with_task_context=task_context
8080
)
8181
await tasks_manager.remove_task(

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 38 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from servicelib.background_task import create_periodic_task
1616
from servicelib.logging_utils import log_catch
1717

18+
from ._store.base import BaseStore
19+
from ._store.in_memory import InMemoryStore
1820
from .errors import (
1921
TaskAlreadyRunningError,
2022
TaskCancelledError,
@@ -38,7 +40,6 @@
3840

3941
RegisteredTaskName: TypeAlias = str
4042
Namespace: TypeAlias = str
41-
TrackedTaskGroupDict: TypeAlias = dict[TaskId, TrackedTask]
4243
TaskContext: TypeAlias = dict[str, Any]
4344

4445

@@ -68,28 +69,28 @@ async def _await_task(task: asyncio.Task) -> None:
6869
await task
6970

7071

71-
def _get_tasks_to_remove(
72-
tracked_tasks: TrackedTaskGroupDict,
72+
async def _get_tasks_to_remove(
73+
tracked_tasks: BaseStore,
7374
stale_task_detect_timeout_s: PositiveFloat,
7475
) -> list[TaskId]:
7576
utc_now = datetime.datetime.now(tz=datetime.UTC)
7677

7778
tasks_to_remove: list[TaskId] = []
7879

79-
for task_id, tracked_task in tracked_tasks.items():
80+
for tracked_task in await tracked_tasks.list():
8081
if tracked_task.fire_and_forget:
8182
continue
8283

8384
if tracked_task.last_status_check is None:
8485
# the task just added or never received a poll request
8586
elapsed_from_start = (utc_now - tracked_task.started).seconds
8687
if elapsed_from_start > stale_task_detect_timeout_s:
87-
tasks_to_remove.append(task_id)
88+
tasks_to_remove.append(tracked_task.task_id)
8889
else:
8990
# the task status was already queried by the client
9091
elapsed_from_last_poll = (utc_now - tracked_task.last_status_check).seconds
9192
if elapsed_from_last_poll > stale_task_detect_timeout_s:
92-
tasks_to_remove.append(task_id)
93+
tasks_to_remove.append(tracked_task.task_id)
9394
return tasks_to_remove
9495

9596

@@ -103,10 +104,11 @@ def __init__(
103104
stale_task_check_interval: datetime.timedelta,
104105
stale_task_detect_timeout: datetime.timedelta,
105106
namespace: Namespace = _DEFAULT_NAMESPACE,
107+
# TODO: inject a Redis connection
106108
):
107109
self.namespace = namespace
108110
# Task groups: Every taskname maps to multiple asyncio.Task within TrackedTask model
109-
self._tracked_tasks: TrackedTaskGroupDict = {}
111+
self._tracked_tasks: BaseStore = InMemoryStore()
110112

111113
self.stale_task_check_interval = stale_task_check_interval
112114
self.stale_task_detect_timeout_s: PositiveFloat = (
@@ -125,7 +127,7 @@ async def setup(self) -> None:
125127
async def teardown(self) -> None:
126128
task_ids_to_remove: deque[TaskId] = deque()
127129

128-
for tracked_task in self._tracked_tasks.values():
130+
for tracked_task in await self._tracked_tasks.list():
129131
task_ids_to_remove.append(tracked_task.task_id)
130132

131133
for task_id in task_ids_to_remove:
@@ -155,7 +157,7 @@ async def _stale_tasks_monitor_worker(self) -> None:
155157
# Since we own the client, we assume (for now) this
156158
# will not be the case.
157159

158-
tasks_to_remove = _get_tasks_to_remove(
160+
tasks_to_remove = await _get_tasks_to_remove(
159161
self._tracked_tasks, self.stale_task_detect_timeout_s
160162
)
161163

@@ -169,25 +171,28 @@ async def _stale_tasks_monitor_worker(self) -> None:
169171
_logger.warning(
170172
"Removing stale task '%s' with status '%s'",
171173
task_id,
172-
self.get_task_status(task_id, with_task_context=None).model_dump_json(),
174+
(
175+
await self.get_task_status(task_id, with_task_context=None)
176+
).model_dump_json(),
173177
)
174178
await self.remove_task(
175179
task_id, with_task_context=None, reraise_errors=False
176180
)
177181

178-
def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]:
182+
async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]:
179183
if not with_task_context:
180184
return [
181-
TaskBase(task_id=task.task_id) for task in self._tracked_tasks.values()
185+
TaskBase(task_id=task.task_id)
186+
for task in (await self._tracked_tasks.list())
182187
]
183188

184189
return [
185190
TaskBase(task_id=task.task_id)
186-
for task in self._tracked_tasks.values()
191+
for task in (await self._tracked_tasks.list())
187192
if task.task_context == with_task_context
188193
]
189194

190-
def _add_task(
195+
async def _add_task(
191196
self,
192197
task: asyncio.Task,
193198
task_progress: TaskProgress,
@@ -204,24 +209,24 @@ def _add_task(
204209
task_context=task_context,
205210
fire_and_forget=fire_and_forget,
206211
)
207-
self._tracked_tasks[task_id] = tracked_task
212+
await self._tracked_tasks.set(task_id, tracked_task)
208213

209214
return tracked_task
210215

211-
def _get_tracked_task(
216+
async def _get_tracked_task(
212217
self, task_id: TaskId, with_task_context: TaskContext | None
213218
) -> TrackedTask:
214-
if task_id not in self._tracked_tasks:
215-
raise TaskNotFoundError(task_id=task_id)
219+
task = await self._tracked_tasks.get(task_id)
216220

217-
task = self._tracked_tasks[task_id]
221+
if task is None:
222+
raise TaskNotFoundError(task_id=task_id)
218223

219224
if with_task_context and task.task_context != with_task_context:
220225
raise TaskNotFoundError(task_id=task_id)
221226

222227
return task
223228

224-
def get_task_status(
229+
async def get_task_status(
225230
self, task_id: TaskId, with_task_context: TaskContext | None
226231
) -> TaskStatus:
227232
"""
@@ -230,7 +235,9 @@ def get_task_status(
230235
231236
raises TaskNotFoundError if the task cannot be found
232237
"""
233-
tracked_task: TrackedTask = self._get_tracked_task(task_id, with_task_context)
238+
tracked_task: TrackedTask = await self._get_tracked_task(
239+
task_id, with_task_context
240+
)
234241
tracked_task.last_status_check = datetime.datetime.now(tz=datetime.UTC)
235242

236243
task = tracked_task.task
@@ -244,7 +251,7 @@ def get_task_status(
244251
}
245252
)
246253

247-
def get_task_result(
254+
async def get_task_result(
248255
self, task_id: TaskId, with_task_context: TaskContext | None
249256
) -> Any:
250257
"""
@@ -254,7 +261,7 @@ def get_task_result(
254261
raises TaskCancelledError if the task was cancelled
255262
raises TaskNotCompletedError if the task is not completed
256263
"""
257-
tracked_task = self._get_tracked_task(task_id, with_task_context)
264+
tracked_task = await self._get_tracked_task(task_id, with_task_context)
258265

259266
try:
260267
return tracked_task.task.result()
@@ -273,7 +280,7 @@ async def cancel_task(
273280
274281
raises TaskNotFoundError if the task cannot be found
275282
"""
276-
tracked_task = self._get_tracked_task(task_id, with_task_context)
283+
tracked_task = await self._get_tracked_task(task_id, with_task_context)
277284
await self._cancel_tracked_task(tracked_task.task, task_id, reraise_errors=True)
278285

279286
@staticmethod
@@ -317,7 +324,7 @@ async def remove_task(
317324
) -> None:
318325
"""cancels and removes task"""
319326
try:
320-
tracked_task = self._get_tracked_task(task_id, with_task_context)
327+
tracked_task = await self._get_tracked_task(task_id, with_task_context)
321328
except TaskNotFoundError:
322329
if reraise_errors:
323330
raise
@@ -327,13 +334,13 @@ async def remove_task(
327334
tracked_task.task, task_id, reraise_errors=reraise_errors
328335
)
329336
finally:
330-
del self._tracked_tasks[task_id]
337+
await self._tracked_tasks.delete(task_id)
331338

332339
def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
333340
unique_part = "unique" if is_unique else f"{uuid4()}"
334341
return f"{self.namespace}.{task_name}.{unique_part}"
335342

336-
def start_task(
343+
async def start_task(
337344
self,
338345
registered_task_name: RegisteredTaskName,
339346
*,
@@ -358,9 +365,10 @@ def start_task(
358365
task_id = self._get_task_id(task_name, is_unique=unique)
359366

360367
# only one unique task can be running
361-
if unique and task_id in self._tracked_tasks:
368+
queried_task = await self._tracked_tasks.get(task_id)
369+
if unique and queried_task is not None:
362370
raise TaskAlreadyRunningError(
363-
task_name=task_name, managed_task=self._tracked_tasks[task_id]
371+
task_name=task_name, managed_task=queried_task
364372
)
365373

366374
task_progress = TaskProgress.create(task_id=task_id)
@@ -377,7 +385,7 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
377385
_progress_task(task_progress, task), name=task_name
378386
)
379387

380-
tracked_task = self._add_task(
388+
tracked_task = await self._add_task(
381389
task=async_task,
382390
task_progress=task_progress,
383391
task_context=task_context or {},

0 commit comments

Comments
 (0)