Skip to content

Commit b57923f

Browse files
author
Andrei Neagu
committed
updated interface
1 parent 61a5c1e commit b57923f

File tree

4 files changed

+23
-28
lines changed

4 files changed

+23
-28
lines changed

packages/service-library/src/servicelib/long_running_tasks/_store/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,12 @@ async def delete_task_data(self, task_id: TaskId) -> None:
2323

2424
@abstractmethod
2525
async def set_as_cancelled(
26-
self, task_id: TaskId, with_task_context: TaskContext | None
26+
self, task_id: TaskId, with_task_context: TaskContext
2727
) -> None:
2828
"""Mark a tracked task as cancelled."""
2929

3030
@abstractmethod
31-
async def get_cancelled(self) -> dict[TaskId, TaskContext | None]:
31+
async def get_cancelled(self) -> dict[TaskId, TaskContext]:
3232
"""Get cancelled tasks."""
3333

3434
@abstractmethod

packages/service-library/src/servicelib/long_running_tasks/_store/in_memory.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ def __init__(self, *args, **kwargs):
77
_ = args
88
_ = kwargs
99
self._tasks_data: dict[TaskId, TaskData] = {}
10-
self._cancelled_tasks: dict[TaskId, TaskContext | None] = {}
10+
self._cancelled_tasks: dict[TaskId, TaskContext] = {}
1111

1212
async def setup(self) -> None:
1313
pass
@@ -28,9 +28,9 @@ async def delete_task_data(self, task_id: TaskId) -> None:
2828
self._tasks_data.pop(task_id, None)
2929

3030
async def set_as_cancelled(
31-
self, task_id: TaskId, with_task_context: TaskContext | None
31+
self, task_id: TaskId, with_task_context: TaskContext
3232
) -> None:
3333
self._cancelled_tasks[task_id] = with_task_context
3434

35-
async def get_cancelled(self) -> dict[TaskId, TaskContext | None]:
35+
async def get_cancelled(self) -> dict[TaskId, TaskContext]:
3636
return self._cancelled_tasks

packages/service-library/src/servicelib/long_running_tasks/_store/redis.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,19 +64,16 @@ async def delete_task_data(self, task_id: TaskId) -> None:
6464
await self.redis.hdel(self._get_redis_hash_key(STORE_TYPE_TASK_DATA), task_id) # type: ignore[misc]
6565

6666
async def set_as_cancelled(
67-
self, task_id: TaskId, with_task_context: TaskContext | None
67+
self, task_id: TaskId, with_task_context: TaskContext
6868
) -> None:
6969
await self.redis.hset(
7070
self._get_redis_hash_key(STORE_TYPE_CANCELLED_TASKS),
7171
task_id,
7272
json_dumps(with_task_context),
7373
) # type: ignore[misc]
7474

75-
async def get_cancelled(self) -> dict[TaskId, TaskContext | None]:
75+
async def get_cancelled(self) -> dict[TaskId, TaskContext]:
7676
result: dict[str, str | None] = await self.redis.hgetall(
7777
self._get_redis_hash_key(STORE_TYPE_CANCELLED_TASKS)
7878
) # type: ignore[misc]
79-
return {
80-
task_id: (json_loads(context) if context else None)
81-
for task_id, context in result.items()
82-
}
79+
return {task_id: json_loads(context) for task_id, context in result.items()}

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import logging
66
import traceback
77
import urllib.parse
8-
from collections import deque
98
from contextlib import suppress
109
from typing import Any, ClassVar, Final, Protocol, TypeAlias
1110
from uuid import uuid4
@@ -73,10 +72,10 @@ async def _await_task(task: asyncio.Task) -> None:
7372
async def _get_tasks_to_remove(
7473
tracked_tasks: BaseStore,
7574
stale_task_detect_timeout_s: PositiveFloat,
76-
) -> list[tuple[TaskId, TaskContext | None]]:
75+
) -> list[tuple[TaskId, TaskContext]]:
7776
utc_now = datetime.datetime.now(tz=datetime.UTC)
7877

79-
tasks_to_remove: list[tuple[TaskId, TaskContext | None]] = []
78+
tasks_to_remove: list[tuple[TaskId, TaskContext]] = []
8079

8180
for tracked_task in await tracked_tasks.list_tasks_data():
8281
if tracked_task.fire_and_forget:
@@ -142,14 +141,12 @@ async def setup(self) -> None:
142141
)
143142

144143
async def teardown(self) -> None:
145-
task_ids_to_remove: deque[TaskId] = deque()
146144

147145
for tracked_task in await self._tasks_data.list_tasks_data():
148-
task_ids_to_remove.append(tracked_task.task_id)
149-
150-
for task_id in task_ids_to_remove:
151146
# when closing we do not care about pending errors
152-
await self.remove_task(task_id, None, reraise_errors=False)
147+
await self.remove_task(
148+
tracked_task.task_id, tracked_task.task_context, reraise_errors=False
149+
)
153150

154151
if self._stale_tasks_monitor_task:
155152
with log_catch(_logger, reraise=False):
@@ -248,7 +245,7 @@ async def _add_task(
248245
return task_data
249246

250247
async def _get_tracked_task(
251-
self, task_id: TaskId, with_task_context: TaskContext | None
248+
self, task_id: TaskId, with_task_context: TaskContext
252249
) -> TaskData:
253250
task_data = await self._tasks_data.get_task_data(task_id)
254251

@@ -261,7 +258,7 @@ async def _get_tracked_task(
261258
return task_data
262259

263260
async def get_task_status(
264-
self, task_id: TaskId, with_task_context: TaskContext | None
261+
self, task_id: TaskId, with_task_context: TaskContext
265262
) -> TaskStatus:
266263
"""
267264
returns: the status of the task, along with updates
@@ -285,7 +282,7 @@ async def get_task_status(
285282
)
286283

287284
async def get_task_result(
288-
self, task_id: TaskId, with_task_context: TaskContext | None
285+
self, task_id: TaskId, with_task_context: TaskContext
289286
) -> Any:
290287
"""
291288
returns: the result of the task
@@ -306,7 +303,7 @@ async def get_task_result(
306303
raise TaskCancelledError(task_id=task_id) from exc
307304

308305
async def cancel_task(
309-
self, task_id: TaskId, with_task_context: TaskContext | None
306+
self, task_id: TaskId, with_task_context: TaskContext
310307
) -> None:
311308
"""
312309
cancels the task
@@ -354,7 +351,7 @@ async def _cancel_tracked_task(
354351
async def remove_task(
355352
self,
356353
task_id: TaskId,
357-
with_task_context: TaskContext | None,
354+
with_task_context: TaskContext,
358355
*,
359356
reraise_errors: bool = True,
360357
) -> None:
@@ -382,7 +379,7 @@ def _get_task_id(self, task_name: str, *, is_unique: bool) -> TaskId:
382379
async def _update_progress(
383380
self,
384381
task_id: TaskId,
385-
task_context: TaskContext | None,
382+
task_context: TaskContext,
386383
task_progress: TaskProgress,
387384
) -> None:
388385
tracked_data = await self._get_tracked_task(task_id, task_context)
@@ -420,10 +417,11 @@ async def start_task(
420417
task_name=task_name, managed_task=queried_task
421418
)
422419

420+
context_to_use = task_context or {}
423421
task_progress = TaskProgress.create(task_id=task_id)
424422
# set update callback
425423
task_progress.set_update_callback(
426-
functools.partial(self._update_progress, task_id, task_context)
424+
functools.partial(self._update_progress, task_id, context_to_use)
427425
)
428426

429427
# bind the task with progress 0 and 1
@@ -441,7 +439,7 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
441439
tracked_task = await self._add_task(
442440
task=async_task,
443441
task_progress=task_progress,
444-
task_context=task_context or {},
442+
task_context=context_to_use,
445443
fire_and_forget=fire_and_forget,
446444
task_id=task_id,
447445
)
@@ -452,10 +450,10 @@ async def _progress_task(progress: TaskProgress, handler: TaskProtocol):
452450
__all__: tuple[str, ...] = (
453451
"TaskAlreadyRunningError",
454452
"TaskCancelledError",
453+
"TaskData",
455454
"TaskId",
456455
"TaskProgress",
457456
"TaskProtocol",
458457
"TaskStatus",
459458
"TasksManager",
460-
"TaskData",
461459
)

0 commit comments

Comments
 (0)