Skip to content

Commit 12fa12c

Browse files
GitHKAndrei Neagu
andauthored
🎨 Makes removal of long running tasks faster (#8350)
Co-authored-by: Andrei Neagu <[email protected]>
1 parent 62565ba commit 12fa12c

File tree

17 files changed

+299
-262
lines changed

17 files changed

+299
-262
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# pylint: disable=protected-access
2+
3+
import pytest
4+
from fastapi import FastAPI
5+
from servicelib.long_running_tasks.errors import TaskNotFoundError
6+
from servicelib.long_running_tasks.manager import (
7+
LongRunningManager,
8+
)
9+
from servicelib.long_running_tasks.models import TaskContext
10+
from servicelib.long_running_tasks.task import TaskId
11+
from tenacity import (
12+
AsyncRetrying,
13+
retry_if_not_exception_type,
14+
stop_after_delay,
15+
wait_fixed,
16+
)
17+
18+
19+
def get_fastapi_long_running_manager(app: FastAPI) -> LongRunningManager:
20+
manager = app.state.long_running_manager
21+
assert isinstance(manager, LongRunningManager)
22+
return manager
23+
24+
25+
async def assert_task_is_no_longer_present(
26+
manager: LongRunningManager, task_id: TaskId, task_context: TaskContext
27+
) -> None:
28+
async for attempt in AsyncRetrying(
29+
reraise=True,
30+
wait=wait_fixed(0.1),
31+
stop=stop_after_delay(60),
32+
retry=retry_if_not_exception_type(TaskNotFoundError),
33+
):
34+
with attempt: # noqa: SIM117
35+
with pytest.raises(TaskNotFoundError):
36+
# use internals to detirmine when it's no longer here
37+
await manager._tasks_manager._get_tracked_task( # noqa: SLF001
38+
task_id, task_context
39+
)

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_routes.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from typing import Annotated, Any
1+
from typing import Any
22

33
from aiohttp import web
4-
from models_library.rest_base import RequestParameters
5-
from pydantic import BaseModel, Field
4+
from pydantic import BaseModel
65

76
from ...aiohttp import status
87
from ...long_running_tasks import lrt_api
98
from ...long_running_tasks.models import TaskGet, TaskId
109
from ..requests_validation import (
1110
parse_request_path_parameters_as,
12-
parse_request_query_parameters_as,
1311
)
1412
from ..rest_responses import create_data_response
1513
from ._manager import get_long_running_manager
@@ -69,29 +67,15 @@ async def get_task_result(request: web.Request) -> web.Response | Any:
6967
)
7068

7169

72-
class _RemoveTaskQueryParams(RequestParameters):
73-
wait_for_removal: Annotated[
74-
bool,
75-
Field(
76-
description=(
77-
"when True waits for the task to be removed "
78-
"completly instead of returning immediately"
79-
)
80-
),
81-
] = True
82-
83-
8470
@routes.delete("/{task_id}", name="remove_task")
8571
async def remove_task(request: web.Request) -> web.Response:
8672
path_params = parse_request_path_parameters_as(_PathParam, request)
87-
query_params = parse_request_query_parameters_as(_RemoveTaskQueryParams, request)
8873
long_running_manager = get_long_running_manager(request.app)
8974

9075
await lrt_api.remove_task(
9176
long_running_manager.rpc_client,
9277
long_running_manager.lrt_namespace,
9378
long_running_manager.get_task_context(request),
9479
path_params.task_id,
95-
wait_for_removal=query_params.wait_for_removal,
9680
)
9781
return web.json_response(status=status.HTTP_204_NO_CONTENT)

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_server.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,6 @@ async def start_long_running_task(
108108
long_running_manager.lrt_namespace,
109109
task_context,
110110
task_id,
111-
wait_for_removal=True,
112111
)
113112
raise
114113

packages/service-library/src/servicelib/fastapi/long_running_tasks/_routes.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Annotated, Any
22

3-
from fastapi import APIRouter, Depends, Query, Request, status
3+
from fastapi import APIRouter, Depends, Request, status
44

55
from ...long_running_tasks import lrt_api
66
from ...long_running_tasks.models import TaskGet, TaskId, TaskResult, TaskStatus
@@ -101,22 +101,11 @@ async def remove_task(
101101
FastAPILongRunningManager, Depends(get_long_running_manager)
102102
],
103103
task_id: TaskId,
104-
*,
105-
wait_for_removal: Annotated[
106-
bool,
107-
Query(
108-
description=(
109-
"when True waits for the task to be removed "
110-
"completly instead of returning immediately"
111-
),
112-
),
113-
] = True,
114104
) -> None:
115105
assert request # nosec
116106
await lrt_api.remove_task(
117107
long_running_manager.rpc_client,
118108
long_running_manager.lrt_namespace,
119109
long_running_manager.get_task_context(request),
120110
task_id=task_id,
121-
wait_for_removal=wait_for_removal,
122111
)

packages/service-library/src/servicelib/long_running_tasks/_redis_store.py

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
from ..redis._client import RedisClientSDK
99
from ..redis._utils import handle_redis_returns_union_types
1010
from ..utils import limited_gather
11-
from .models import LRTNamespace, TaskContext, TaskData, TaskId
11+
from .models import LRTNamespace, TaskData, TaskId
1212

1313
_STORE_TYPE_TASK_DATA: Final[str] = "TD"
14-
_STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT"
15-
_LIST_CONCURRENCY: Final[int] = 2
14+
_LIST_CONCURRENCY: Final[int] = 3
15+
_MARKED_FOR_REMOVAL_FIELD: Final[str] = "marked_for_removal"
1616

1717

1818
def _to_redis_hash_mapping(data: dict[str, Any]) -> dict[str, str]:
@@ -52,11 +52,6 @@ def _get_redis_key_task_data_match(self) -> str:
5252
def _get_redis_task_data_key(self, task_id: TaskId) -> str:
5353
return f"{self.namespace}:{_STORE_TYPE_TASK_DATA}:{task_id}"
5454

55-
def _get_key_to_remove(self) -> str:
56-
return f"{self.namespace}:{_STORE_TYPE_CANCELLED_TASKS}"
57-
58-
# TaskData
59-
6055
async def get_task_data(self, task_id: TaskId) -> TaskData | None:
6156
result: dict[str, Any] = await handle_redis_returns_union_types(
6257
self._redis.hgetall(
@@ -115,24 +110,18 @@ async def delete_task_data(self, task_id: TaskId) -> None:
115110
self._redis.delete(self._get_redis_task_data_key(task_id))
116111
)
117112

118-
# to cancel
119-
120-
async def mark_task_for_removal(
121-
self, task_id: TaskId, with_task_context: TaskContext
122-
) -> None:
113+
async def mark_for_removal(self, task_id: TaskId) -> None:
123114
await handle_redis_returns_union_types(
124115
self._redis.hset(
125-
self._get_key_to_remove(), task_id, json_dumps(with_task_context)
116+
self._get_redis_task_data_key(task_id),
117+
mapping=_to_redis_hash_mapping({_MARKED_FOR_REMOVAL_FIELD: True}),
126118
)
127119
)
128120

129-
async def completed_task_removal(self, task_id: TaskId) -> None:
130-
await handle_redis_returns_union_types(
131-
self._redis.hdel(self._get_key_to_remove(), task_id)
132-
)
133-
134-
async def list_tasks_to_remove(self) -> dict[TaskId, TaskContext]:
135-
result: dict[str, str | None] = await handle_redis_returns_union_types(
136-
self._redis.hgetall(self._get_key_to_remove())
121+
async def is_marked_for_removal(self, task_id: TaskId) -> bool:
122+
result = await handle_redis_returns_union_types(
123+
self._redis.hget(
124+
self._get_redis_task_data_key(task_id), _MARKED_FOR_REMOVAL_FIELD
125+
)
137126
)
138-
return {task_id: json_loads(context) for task_id, context in result.items()}
127+
return False if result is None else json_loads(result)

packages/service-library/src/servicelib/long_running_tasks/_rpc_client.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -118,26 +118,13 @@ async def remove_task(
118118
*,
119119
task_context: TaskContext,
120120
task_id: TaskId,
121-
wait_for_removal: bool,
122-
cancellation_timeout: timedelta | None,
123121
) -> None:
124-
timeout_s = (
125-
None
126-
if cancellation_timeout is None
127-
else int(cancellation_timeout.total_seconds())
128-
)
129-
130-
# NOTE: task always gets cancelled even if not waiting for it
131-
# request will return immediatlye, no need to wait so much
132-
if wait_for_removal is False:
133-
timeout_s = _RPC_TIMEOUT_SHORT_REQUESTS
134122

135123
result = await rabbitmq_rpc_client.request(
136124
get_rabbit_namespace(namespace),
137125
TypeAdapter(RPCMethodName).validate_python("remove_task"),
138126
task_context=task_context,
139127
task_id=task_id,
140-
wait_for_removal=wait_for_removal,
141-
timeout_s=timeout_s,
128+
timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS,
142129
)
143130
assert result is None # nosec

packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def get_task_result(
8888
await long_running_manager.tasks_manager.remove_task(
8989
task_id,
9090
with_task_context=task_context,
91-
wait_for_removal=True,
91+
wait_for_removal=False,
9292
)
9393

9494

@@ -98,10 +98,7 @@ async def remove_task(
9898
*,
9999
task_context: TaskContext,
100100
task_id: TaskId,
101-
wait_for_removal: bool,
102101
) -> None:
103102
await long_running_manager.tasks_manager.remove_task(
104-
task_id,
105-
with_task_context=task_context,
106-
wait_for_removal=wait_for_removal,
103+
task_id, with_task_context=task_context, wait_for_removal=False
107104
)

packages/service-library/src/servicelib/long_running_tasks/lrt_api.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from datetime import timedelta
21
from typing import Any
32

43
from ..rabbitmq._client_rpc import RabbitMQRPCClient
@@ -103,9 +102,6 @@ async def remove_task(
103102
lrt_namespace: LRTNamespace,
104103
task_context: TaskContext,
105104
task_id: TaskId,
106-
*,
107-
wait_for_removal: bool,
108-
cancellation_timeout: timedelta | None = None,
109105
) -> None:
110106
"""cancels and removes a task
111107
@@ -116,6 +112,4 @@ async def remove_task(
116112
lrt_namespace,
117113
task_id=task_id,
118114
task_context=task_context,
119-
wait_for_removal=wait_for_removal,
120-
cancellation_timeout=cancellation_timeout,
121115
)

packages/service-library/src/servicelib/long_running_tasks/models.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class TaskData(BaseModel):
5050
task_id: str
5151
task_progress: TaskProgress
5252
# NOTE: this context lifetime is with the tracked task (similar to aiohttp storage concept)
53-
task_context: dict[str, Any]
53+
task_context: TaskContext
5454
fire_and_forget: Annotated[
5555
bool,
5656
Field(
@@ -78,6 +78,10 @@ class TaskData(BaseModel):
7878
result_field: Annotated[
7979
ResultField | None, Field(description="the result of the task")
8080
] = None
81+
marked_for_removal: Annotated[
82+
bool,
83+
Field(description=("if True, indicates the task is marked for removal")),
84+
] = False
8185

8286
model_config = ConfigDict(
8387
arbitrary_types_allowed=True,

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ async def _stale_tasks_monitor(self) -> None:
282282
# we just print the status from where one can infer the above
283283
with suppress(TaskNotFoundError):
284284
task_status = await self.get_task_status(
285-
task_id, with_task_context=task_context
285+
task_id, with_task_context=task_context, exclude_to_remove=False
286286
)
287287
with log_context(
288288
_logger,
@@ -300,11 +300,17 @@ async def _cancelled_tasks_removal(self) -> None:
300300
"""
301301
self._started_event_task_cancelled_tasks_removal.set()
302302

303-
to_remove = await self._tasks_data.list_tasks_to_remove()
304-
for task_id in to_remove:
305-
await self._attempt_to_remove_local_task(task_id)
303+
tasks_data = await self._tasks_data.list_tasks_data()
304+
await limited_gather(
305+
*(
306+
self._attempt_to_remove_local_task(x.task_id)
307+
for x in tasks_data
308+
if x.marked_for_removal is True
309+
),
310+
limit=_PARALLEL_TASKS_CANCELLATION,
311+
)
306312

307-
async def _tasks_monitor(self) -> None:
313+
async def _tasks_monitor(self) -> None: # noqa: C901
308314
"""
309315
A task which monitors locally running tasks and updates their status
310316
in the Redis store when they are done.
@@ -396,12 +402,14 @@ async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBa
396402
return [
397403
TaskBase(task_id=task.task_id)
398404
for task in (await self._tasks_data.list_tasks_data())
405+
if task.marked_for_removal is False
399406
]
400407

401408
return [
402409
TaskBase(task_id=task.task_id)
403410
for task in (await self._tasks_data.list_tasks_data())
404411
if task.task_context == with_task_context
412+
and task.marked_for_removal is False
405413
]
406414

407415
async def _get_tracked_task(
@@ -418,14 +426,21 @@ async def _get_tracked_task(
418426
return task_data
419427

420428
async def get_task_status(
421-
self, task_id: TaskId, with_task_context: TaskContext
429+
self,
430+
task_id: TaskId,
431+
with_task_context: TaskContext,
432+
*,
433+
exclude_to_remove: bool = True,
422434
) -> TaskStatus:
423435
"""
424436
returns: the status of the task, along with updates
425437
form the progress
426438
427439
raises TaskNotFoundError if the task cannot be found
428440
"""
441+
if exclude_to_remove and await self._tasks_data.is_marked_for_removal(task_id):
442+
raise TaskNotFoundError(task_id=task_id)
443+
429444
task_data = await self._get_tracked_task(task_id, with_task_context)
430445

431446
await self._tasks_data.update_task_data(
@@ -460,6 +475,9 @@ async def get_task_result(
460475
raises TaskNotFoundError if the task cannot be found
461476
raises TaskNotCompletedError if the task is not completed
462477
"""
478+
if await self._tasks_data.is_marked_for_removal(task_id):
479+
raise TaskNotFoundError(task_id=task_id)
480+
463481
tracked_task = await self._get_tracked_task(task_id, with_task_context)
464482

465483
if not tracked_task.is_done or tracked_task.result_field is None:
@@ -473,7 +491,6 @@ async def _attempt_to_remove_local_task(self, task_id: TaskId) -> None:
473491
task_to_cancel = self._created_tasks.pop(task_id, None)
474492
if task_to_cancel is not None:
475493
await cancel_wait_task(task_to_cancel)
476-
await self._tasks_data.completed_task_removal(task_id)
477494
await self._tasks_data.delete_task_data(task_id)
478495

479496
async def remove_task(
@@ -487,11 +504,12 @@ async def remove_task(
487504
cancels and removes task
488505
raises TaskNotFoundError if the task cannot be found
489506
"""
507+
if await self._tasks_data.is_marked_for_removal(task_id):
508+
raise TaskNotFoundError(task_id=task_id)
509+
490510
tracked_task = await self._get_tracked_task(task_id, with_task_context)
491511

492-
await self._tasks_data.mark_task_for_removal(
493-
tracked_task.task_id, tracked_task.task_context
494-
)
512+
await self._tasks_data.mark_for_removal(tracked_task.task_id)
495513

496514
if not wait_for_removal:
497515
return

0 commit comments

Comments
 (0)