Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# pylint: disable=protected-access

import pytest
from fastapi import FastAPI
from servicelib.long_running_tasks.errors import TaskNotFoundError
from servicelib.long_running_tasks.manager import (
LongRunningManager,
)
from servicelib.long_running_tasks.models import TaskContext
from servicelib.long_running_tasks.task import TaskId
from tenacity import (
AsyncRetrying,
retry_if_not_exception_type,
stop_after_delay,
wait_fixed,
)


def get_fastapi_long_running_manager(app: FastAPI) -> LongRunningManager:
manager = app.state.long_running_manager
assert isinstance(manager, LongRunningManager)
return manager


async def assert_task_is_no_longer_present(
manager: LongRunningManager, task_id: TaskId, task_context: TaskContext
) -> None:
async for attempt in AsyncRetrying(
reraise=True,
wait=wait_fixed(0.1),
stop=stop_after_delay(60),
retry=retry_if_not_exception_type(TaskNotFoundError),
):
with attempt: # noqa: SIM117
with pytest.raises(TaskNotFoundError):
# use internals to detirmine when it's no longer here
await manager._tasks_manager._get_tracked_task( # noqa: SLF001
task_id, task_context
)
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from typing import Annotated, Any
from typing import Any

from aiohttp import web
from models_library.rest_base import RequestParameters
from pydantic import BaseModel, Field
from pydantic import BaseModel

from ...aiohttp import status
from ...long_running_tasks import lrt_api
from ...long_running_tasks.models import TaskGet, TaskId
from ..requests_validation import (
parse_request_path_parameters_as,
parse_request_query_parameters_as,
)
from ..rest_responses import create_data_response
from ._manager import get_long_running_manager
Expand Down Expand Up @@ -69,29 +67,15 @@ async def get_task_result(request: web.Request) -> web.Response | Any:
)


class _RemoveTaskQueryParams(RequestParameters):
wait_for_removal: Annotated[
bool,
Field(
description=(
"when True waits for the task to be removed "
"completly instead of returning immediately"
)
),
] = True


@routes.delete("/{task_id}", name="remove_task")
async def remove_task(request: web.Request) -> web.Response:
path_params = parse_request_path_parameters_as(_PathParam, request)
query_params = parse_request_query_parameters_as(_RemoveTaskQueryParams, request)
long_running_manager = get_long_running_manager(request.app)

await lrt_api.remove_task(
long_running_manager.rpc_client,
long_running_manager.lrt_namespace,
long_running_manager.get_task_context(request),
path_params.task_id,
wait_for_removal=query_params.wait_for_removal,
)
return web.json_response(status=status.HTTP_204_NO_CONTENT)
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ async def start_long_running_task(
long_running_manager.lrt_namespace,
task_context,
task_id,
wait_for_removal=True,
)
raise

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Annotated, Any

from fastapi import APIRouter, Depends, Query, Request, status
from fastapi import APIRouter, Depends, Request, status

from ...long_running_tasks import lrt_api
from ...long_running_tasks.models import TaskGet, TaskId, TaskResult, TaskStatus
Expand Down Expand Up @@ -101,22 +101,11 @@ async def remove_task(
FastAPILongRunningManager, Depends(get_long_running_manager)
],
task_id: TaskId,
*,
wait_for_removal: Annotated[
bool,
Query(
description=(
"when True waits for the task to be removed "
"completly instead of returning immediately"
),
),
] = True,
) -> None:
assert request # nosec
await lrt_api.remove_task(
long_running_manager.rpc_client,
long_running_manager.lrt_namespace,
long_running_manager.get_task_context(request),
task_id=task_id,
wait_for_removal=wait_for_removal,
)
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from ..redis._client import RedisClientSDK
from ..redis._utils import handle_redis_returns_union_types
from ..utils import limited_gather
from .models import LRTNamespace, TaskContext, TaskData, TaskId
from .models import LRTNamespace, TaskData, TaskId

_STORE_TYPE_TASK_DATA: Final[str] = "TD"
_STORE_TYPE_CANCELLED_TASKS: Final[str] = "CT"
_LIST_CONCURRENCY: Final[int] = 2
_LIST_CONCURRENCY: Final[int] = 3
_MARKED_FOR_REMOVAL_FIELD: Final[str] = "marked_for_removal"


def _to_redis_hash_mapping(data: dict[str, Any]) -> dict[str, str]:
Expand Down Expand Up @@ -52,11 +52,6 @@ def _get_redis_key_task_data_match(self) -> str:
def _get_redis_task_data_key(self, task_id: TaskId) -> str:
return f"{self.namespace}:{_STORE_TYPE_TASK_DATA}:{task_id}"

def _get_key_to_remove(self) -> str:
return f"{self.namespace}:{_STORE_TYPE_CANCELLED_TASKS}"

# TaskData

async def get_task_data(self, task_id: TaskId) -> TaskData | None:
result: dict[str, Any] = await handle_redis_returns_union_types(
self._redis.hgetall(
Expand Down Expand Up @@ -115,24 +110,18 @@ async def delete_task_data(self, task_id: TaskId) -> None:
self._redis.delete(self._get_redis_task_data_key(task_id))
)

# to cancel

async def mark_task_for_removal(
self, task_id: TaskId, with_task_context: TaskContext
) -> None:
async def mark_for_removal(self, task_id: TaskId) -> None:
await handle_redis_returns_union_types(
self._redis.hset(
self._get_key_to_remove(), task_id, json_dumps(with_task_context)
self._get_redis_task_data_key(task_id),
mapping=_to_redis_hash_mapping({_MARKED_FOR_REMOVAL_FIELD: True}),
)
)

async def completed_task_removal(self, task_id: TaskId) -> None:
await handle_redis_returns_union_types(
self._redis.hdel(self._get_key_to_remove(), task_id)
)

async def list_tasks_to_remove(self) -> dict[TaskId, TaskContext]:
result: dict[str, str | None] = await handle_redis_returns_union_types(
self._redis.hgetall(self._get_key_to_remove())
async def is_marked_for_removal(self, task_id: TaskId) -> bool:
result = await handle_redis_returns_union_types(
self._redis.hget(
self._get_redis_task_data_key(task_id), _MARKED_FOR_REMOVAL_FIELD
)
)
return {task_id: json_loads(context) for task_id, context in result.items()}
return False if result is None else json_loads(result)
Original file line number Diff line number Diff line change
Expand Up @@ -118,26 +118,13 @@ async def remove_task(
*,
task_context: TaskContext,
task_id: TaskId,
wait_for_removal: bool,
cancellation_timeout: timedelta | None,
) -> None:
timeout_s = (
None
if cancellation_timeout is None
else int(cancellation_timeout.total_seconds())
)

# NOTE: task always gets cancelled even if not waiting for it
# request will return immediatlye, no need to wait so much
if wait_for_removal is False:
timeout_s = _RPC_TIMEOUT_SHORT_REQUESTS

result = await rabbitmq_rpc_client.request(
get_rabbit_namespace(namespace),
TypeAdapter(RPCMethodName).validate_python("remove_task"),
task_context=task_context,
task_id=task_id,
wait_for_removal=wait_for_removal,
timeout_s=timeout_s,
timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS,
)
assert result is None # nosec
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def get_task_result(
await long_running_manager.tasks_manager.remove_task(
task_id,
with_task_context=task_context,
wait_for_removal=True,
wait_for_removal=False,
)


Expand All @@ -98,10 +98,7 @@ async def remove_task(
*,
task_context: TaskContext,
task_id: TaskId,
wait_for_removal: bool,
) -> None:
await long_running_manager.tasks_manager.remove_task(
task_id,
with_task_context=task_context,
wait_for_removal=wait_for_removal,
task_id, with_task_context=task_context, wait_for_removal=False
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import timedelta
from typing import Any

from ..rabbitmq._client_rpc import RabbitMQRPCClient
Expand Down Expand Up @@ -103,9 +102,6 @@ async def remove_task(
lrt_namespace: LRTNamespace,
task_context: TaskContext,
task_id: TaskId,
*,
wait_for_removal: bool,
cancellation_timeout: timedelta | None = None,
) -> None:
"""cancels and removes a task

Expand All @@ -116,6 +112,4 @@ async def remove_task(
lrt_namespace,
task_id=task_id,
task_context=task_context,
wait_for_removal=wait_for_removal,
cancellation_timeout=cancellation_timeout,
)
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class TaskData(BaseModel):
task_id: str
task_progress: TaskProgress
# NOTE: this context lifetime is with the tracked task (similar to aiohttp storage concept)
task_context: dict[str, Any]
task_context: TaskContext
fire_and_forget: Annotated[
bool,
Field(
Expand Down Expand Up @@ -78,6 +78,10 @@ class TaskData(BaseModel):
result_field: Annotated[
ResultField | None, Field(description="the result of the task")
] = None
marked_for_removal: Annotated[
bool,
Field(description=("if True, indicates the task is marked for removal")),
] = False

model_config = ConfigDict(
arbitrary_types_allowed=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def _stale_tasks_monitor(self) -> None:
# we just print the status from where one can infer the above
with suppress(TaskNotFoundError):
task_status = await self.get_task_status(
task_id, with_task_context=task_context
task_id, with_task_context=task_context, exclude_to_remove=False
)
with log_context(
_logger,
Expand All @@ -300,11 +300,17 @@ async def _cancelled_tasks_removal(self) -> None:
"""
self._started_event_task_cancelled_tasks_removal.set()

to_remove = await self._tasks_data.list_tasks_to_remove()
for task_id in to_remove:
await self._attempt_to_remove_local_task(task_id)
tasks_data = await self._tasks_data.list_tasks_data()
await limited_gather(
*(
self._attempt_to_remove_local_task(x.task_id)
for x in tasks_data
if x.marked_for_removal is True
),
limit=_PARALLEL_TASKS_CANCELLATION,
)

async def _tasks_monitor(self) -> None:
async def _tasks_monitor(self) -> None: # noqa: C901
"""
A task which monitors locally running tasks and updates their status
in the Redis store when they are done.
Expand Down Expand Up @@ -396,12 +402,14 @@ async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBa
return [
TaskBase(task_id=task.task_id)
for task in (await self._tasks_data.list_tasks_data())
if task.marked_for_removal is False
]

return [
TaskBase(task_id=task.task_id)
for task in (await self._tasks_data.list_tasks_data())
if task.task_context == with_task_context
and task.marked_for_removal is False
]

async def _get_tracked_task(
Expand All @@ -418,14 +426,21 @@ async def _get_tracked_task(
return task_data

async def get_task_status(
self, task_id: TaskId, with_task_context: TaskContext
self,
task_id: TaskId,
with_task_context: TaskContext,
*,
exclude_to_remove: bool = True,
) -> TaskStatus:
"""
returns: the status of the task, along with updates
form the progress

raises TaskNotFoundError if the task cannot be found
"""
if exclude_to_remove and await self._tasks_data.is_marked_for_removal(task_id):
raise TaskNotFoundError(task_id=task_id)

task_data = await self._get_tracked_task(task_id, with_task_context)

await self._tasks_data.update_task_data(
Expand Down Expand Up @@ -460,6 +475,9 @@ async def get_task_result(
raises TaskNotFoundError if the task cannot be found
raises TaskNotCompletedError if the task is not completed
"""
if await self._tasks_data.is_marked_for_removal(task_id):
raise TaskNotFoundError(task_id=task_id)

tracked_task = await self._get_tracked_task(task_id, with_task_context)

if not tracked_task.is_done or tracked_task.result_field is None:
Expand All @@ -473,7 +491,6 @@ async def _attempt_to_remove_local_task(self, task_id: TaskId) -> None:
task_to_cancel = self._created_tasks.pop(task_id, None)
if task_to_cancel is not None:
await cancel_wait_task(task_to_cancel)
await self._tasks_data.completed_task_removal(task_id)
await self._tasks_data.delete_task_data(task_id)

async def remove_task(
Expand All @@ -487,11 +504,12 @@ async def remove_task(
cancels and removes task
raises TaskNotFoundError if the task cannot be found
"""
if await self._tasks_data.is_marked_for_removal(task_id):
raise TaskNotFoundError(task_id=task_id)

tracked_task = await self._get_tracked_task(task_id, with_task_context)

await self._tasks_data.mark_task_for_removal(
tracked_task.task_id, tracked_task.task_context
)
await self._tasks_data.mark_for_removal(tracked_task.task_id)

if not wait_for_removal:
return
Expand Down
Loading
Loading