Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# pylint: disable=protected-access

import pytest
from fastapi import FastAPI
from servicelib.long_running_tasks.errors import TaskNotFoundError
from servicelib.long_running_tasks.manager import (
LongRunningManager,
)
from servicelib.long_running_tasks.models import TaskContext
from servicelib.long_running_tasks.task import TaskId
from tenacity import (
AsyncRetrying,
TryAgain,
retry_if_exception_type,
stop_after_delay,
wait_fixed,
)


def get_fastapi_long_running_manager(app: FastAPI) -> LongRunningManager:
manager = app.state.long_running_manager
assert isinstance(manager, LongRunningManager)
return manager


async def assert_task_is_no_longer_present(
manager: LongRunningManager, task_id: TaskId, task_context: TaskContext
) -> None:
async for attempt in AsyncRetrying(
reraise=True,
wait=wait_fixed(0.1),
stop=stop_after_delay(60),
retry=retry_if_exception_type((AssertionError, TryAgain)),
):
with attempt: # noqa: SIM117
with pytest.raises(TaskNotFoundError): # noqa: PT012
# use internals to detirmine when it's no longer here
await manager._tasks_manager._get_tracked_task( # noqa: SLF001
task_id, task_context
)
raise TryAgain
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
from typing import Annotated, Any
from typing import Any

from aiohttp import web
from models_library.rest_base import RequestParameters
from pydantic import BaseModel, Field
from pydantic import BaseModel

from ...aiohttp import status
from ...long_running_tasks import lrt_api
from ...long_running_tasks.models import TaskGet, TaskId
from ..requests_validation import (
parse_request_path_parameters_as,
parse_request_query_parameters_as,
)
from ..rest_responses import create_data_response
from ._manager import get_long_running_manager
Expand Down Expand Up @@ -69,29 +67,15 @@ async def get_task_result(request: web.Request) -> web.Response | Any:
)


class _RemoveTaskQueryParams(RequestParameters):
wait_for_removal: Annotated[
bool,
Field(
description=(
"when True waits for the task to be removed "
"completly instead of returning immediately"
)
),
] = True


@routes.delete("/{task_id}", name="remove_task")
async def remove_task(request: web.Request) -> web.Response:
path_params = parse_request_path_parameters_as(_PathParam, request)
query_params = parse_request_query_parameters_as(_RemoveTaskQueryParams, request)
long_running_manager = get_long_running_manager(request.app)

await lrt_api.remove_task(
long_running_manager.rpc_client,
long_running_manager.lrt_namespace,
long_running_manager.get_task_context(request),
path_params.task_id,
wait_for_removal=query_params.wait_for_removal,
)
return web.json_response(status=status.HTTP_204_NO_CONTENT)
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,6 @@ async def start_long_running_task(
long_running_manager.lrt_namespace,
task_context,
task_id,
wait_for_removal=True,
)
raise

Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Annotated, Any

from fastapi import APIRouter, Depends, Query, Request, status
from fastapi import APIRouter, Depends, Request, status

from ...long_running_tasks import lrt_api
from ...long_running_tasks.models import TaskGet, TaskId, TaskResult, TaskStatus
Expand Down Expand Up @@ -101,22 +101,11 @@ async def remove_task(
FastAPILongRunningManager, Depends(get_long_running_manager)
],
task_id: TaskId,
*,
wait_for_removal: Annotated[
bool,
Query(
description=(
"when True waits for the task to be removed "
"completly instead of returning immediately"
),
),
] = True,
) -> None:
assert request # nosec
await lrt_api.remove_task(
long_running_manager.rpc_client,
long_running_manager.lrt_namespace,
long_running_manager.get_task_context(request),
task_id=task_id,
wait_for_removal=wait_for_removal,
)
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,12 @@ async def mark_task_for_removal(
)
)

async def is_maked_for_removal(self, task_id: TaskId) -> bool:
result: bool = await handle_redis_returns_union_types(
self._redis.hexists(self._get_key_to_remove(), task_id)
)
return result

async def completed_task_removal(self, task_id: TaskId) -> None:
await handle_redis_returns_union_types(
self._redis.hdel(self._get_key_to_remove(), task_id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,26 +118,13 @@ async def remove_task(
*,
task_context: TaskContext,
task_id: TaskId,
wait_for_removal: bool,
cancellation_timeout: timedelta | None,
) -> None:
timeout_s = (
None
if cancellation_timeout is None
else int(cancellation_timeout.total_seconds())
)

# NOTE: task always gets cancelled even if not waiting for it
# request will return immediatlye, no need to wait so much
if wait_for_removal is False:
timeout_s = _RPC_TIMEOUT_SHORT_REQUESTS

result = await rabbitmq_rpc_client.request(
get_rabbit_namespace(namespace),
TypeAdapter(RPCMethodName).validate_python("remove_task"),
task_context=task_context,
task_id=task_id,
wait_for_removal=wait_for_removal,
timeout_s=timeout_s,
timeout_s=_RPC_TIMEOUT_SHORT_REQUESTS,
)
assert result is None # nosec
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def get_task_result(
await long_running_manager.tasks_manager.remove_task(
task_id,
with_task_context=task_context,
wait_for_removal=True,
wait_for_removal=False,
)


Expand All @@ -98,10 +98,7 @@ async def remove_task(
*,
task_context: TaskContext,
task_id: TaskId,
wait_for_removal: bool,
) -> None:
await long_running_manager.tasks_manager.remove_task(
task_id,
with_task_context=task_context,
wait_for_removal=wait_for_removal,
task_id, with_task_context=task_context, wait_for_removal=False
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from datetime import timedelta
from typing import Any

from ..rabbitmq._client_rpc import RabbitMQRPCClient
Expand Down Expand Up @@ -103,9 +102,6 @@ async def remove_task(
lrt_namespace: LRTNamespace,
task_context: TaskContext,
task_id: TaskId,
*,
wait_for_removal: bool,
cancellation_timeout: timedelta | None = None,
) -> None:
"""cancels and removes a task

Expand All @@ -116,6 +112,4 @@ async def remove_task(
lrt_namespace,
task_id=task_id,
task_context=task_context,
wait_for_removal=wait_for_removal,
cancellation_timeout=cancellation_timeout,
)
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ async def _stale_tasks_monitor(self) -> None:
# we just print the status from where one can infer the above
with suppress(TaskNotFoundError):
task_status = await self.get_task_status(
task_id, with_task_context=task_context
task_id, with_task_context=task_context, exclude_removed=False
)
with log_context(
_logger,
Expand Down Expand Up @@ -386,17 +386,26 @@ async def _tasks_monitor(self) -> None:
await self._tasks_data.update_task_data(task_id, updates=updates)

async def list_tasks(self, with_task_context: TaskContext | None) -> list[TaskBase]:
if not with_task_context:
async def _() -> list[TaskBase]:
if not with_task_context:
return [
TaskBase(task_id=task.task_id)
for task in (await self._tasks_data.list_tasks_data())
]

return [
TaskBase(task_id=task.task_id)
for task in (await self._tasks_data.list_tasks_data())
if task.task_context == with_task_context
]

return [
TaskBase(task_id=task.task_id)
for task in (await self._tasks_data.list_tasks_data())
if task.task_context == with_task_context
]
result = await _()

if len(result) == 0:
return []

to_remove = await self._tasks_data.list_tasks_to_remove()
return [r for r in result if r.task_id not in to_remove]

async def _get_tracked_task(
self, task_id: TaskId, with_task_context: TaskContext
Expand All @@ -412,14 +421,21 @@ async def _get_tracked_task(
return task_data

async def get_task_status(
self, task_id: TaskId, with_task_context: TaskContext
self,
task_id: TaskId,
with_task_context: TaskContext,
*,
exclude_removed: bool = True,
) -> TaskStatus:
"""
returns: the status of the task, along with updates
form the progress

raises TaskNotFoundError if the task cannot be found
"""
if exclude_removed and await self._tasks_data.is_maked_for_removal(task_id):
raise TaskNotFoundError(task_id=task_id)

task_data = await self._get_tracked_task(task_id, with_task_context)

await self._tasks_data.update_task_data(
Expand Down Expand Up @@ -454,6 +470,9 @@ async def get_task_result(
raises TaskNotFoundError if the task cannot be found
raises TaskNotCompletedError if the task is not completed
"""
if await self._tasks_data.is_maked_for_removal(task_id):
raise TaskNotFoundError(task_id=task_id)

tracked_task = await self._get_tracked_task(task_id, with_task_context)

if not tracked_task.is_done or tracked_task.result_field is None:
Expand Down Expand Up @@ -481,6 +500,9 @@ async def remove_task(
cancels and removes task
raises TaskNotFoundError if the task cannot be found
"""
if await self._tasks_data.is_maked_for_removal(task_id):
raise TaskNotFoundError(task_id=task_id)

tracked_task = await self._get_tracked_task(task_id, with_task_context)

await self._tasks_data.mark_task_for_removal(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ async def test_workflow(store: RedisStore, task_data: TaskData) -> None:
# cancelled tasks
assert await store.list_tasks_to_remove() == {}

assert await store.is_maked_for_removal(task_data.task_id) is False

await store.mark_task_for_removal(task_data.task_id, task_data.task_context)

assert await store.is_maked_for_removal(task_data.task_id) is True

assert await store.list_tasks_to_remove() == {
task_data.task_id: task_data.task_context
}
Expand Down
Loading
Loading