Skip to content

Commit 10df578

Browse files
committed
Merge branch 'master' into 1973-add-celery-worker-to-api-server
2 parents d829b01 + b5836c9 commit 10df578

File tree

67 files changed

+1089
-451
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

67 files changed

+1089
-451
lines changed

packages/models-library/src/models_library/functions.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@
1313
from models_library.users import UserID
1414
from models_library.utils.enums import StrAutoEnum
1515
from pydantic import BaseModel, ConfigDict, Field
16-
from servicelib.celery.models import TaskID
1716

1817
from .projects import ProjectID
1918
from .utils.change_case import snake_to_camel
2019

20+
TaskID: TypeAlias = str
2121
FunctionID: TypeAlias = UUID
2222
FunctionJobID: TypeAlias = UUID
2323
FileID: TypeAlias = UUID
@@ -206,9 +206,11 @@ class RegisteredProjectFunctionJob(ProjectFunctionJob, RegisteredFunctionJobBase
206206

207207

208208
class RegisteredProjectFunctionJobPatch(BaseModel):
209-
function_class: FunctionClass
209+
function_class: Literal[FunctionClass.PROJECT] = FunctionClass.PROJECT
210210
title: str | None
211211
description: str | None
212+
inputs: FunctionInputs
213+
outputs: FunctionOutputs
212214
project_job_id: ProjectID | None
213215
job_creation_task_id: TaskID | None
214216

@@ -224,9 +226,11 @@ class RegisteredSolverFunctionJob(SolverFunctionJob, RegisteredFunctionJobBase):
224226

225227

226228
class RegisteredSolverFunctionJobPatch(BaseModel):
227-
function_class: FunctionClass
229+
function_class: Literal[FunctionClass.SOLVER] = FunctionClass.SOLVER
228230
title: str | None
229231
description: str | None
232+
inputs: FunctionInputs
233+
outputs: FunctionOutputs
230234
solver_job_id: ProjectID | None
231235
job_creation_task_id: TaskID | None
232236

@@ -240,8 +244,10 @@ class RegisteredPythonCodeFunctionJob(PythonCodeFunctionJob, RegisteredFunctionJ
240244

241245

242246
class RegisteredPythonCodeFunctionJobPatch(BaseModel):
243-
function_class: FunctionClass
247+
function_class: Literal[FunctionClass.PYTHON_CODE] = FunctionClass.PYTHON_CODE
244248
title: str | None
249+
inputs: FunctionInputs
250+
outputs: FunctionOutputs
245251
description: str | None
246252

247253

packages/models-library/src/models_library/functions_errors.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,3 +158,8 @@ class FunctionJobCollectionsExecuteApiAccessDeniedError(FunctionBaseError):
158158
"User {user_id} does not have the permission to execute function job collections"
159159
)
160160
status_code: int = 403 # Forbidden
161+
162+
163+
class FunctionJobPatchModelIncompatibleError(FunctionBaseError):
164+
msg_template = "Incompatible patch model for Function '{function_id}' in product '{product_name}'."
165+
status_code: int = 422

packages/service-library/src/servicelib/aiohttp/long_running_tasks/_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from aiohttp import web
22

3-
from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager
3+
from ...long_running_tasks.manager import LongRunningManager
44
from ...long_running_tasks.models import TaskContext
55
from ._constants import APP_LONG_RUNNING_MANAGER_KEY
66
from ._request import get_task_context
77

88

9-
class AiohttpLongRunningManager(BaseLongRunningManager):
9+
class AiohttpLongRunningManager(LongRunningManager):
1010

1111
@staticmethod
1212
def get_task_context(request: web.Request) -> TaskContext:

packages/service-library/src/servicelib/fastapi/long_running_tasks/_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from fastapi import Request
22

3-
from ...long_running_tasks.base_long_running_manager import BaseLongRunningManager
3+
from ...long_running_tasks.manager import LongRunningManager
44
from ...long_running_tasks.models import TaskContext
55

66

7-
class FastAPILongRunningManager(BaseLongRunningManager):
7+
class FastAPILongRunningManager(LongRunningManager):
88
@staticmethod
99
def get_task_context(request: Request) -> TaskContext:
1010
_ = request

packages/service-library/src/servicelib/long_running_tasks/_rpc_server.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515
_logger = logging.getLogger(__name__)
1616

1717
if TYPE_CHECKING:
18-
from .base_long_running_manager import BaseLongRunningManager
18+
from .manager import LongRunningManager
1919

2020

2121
router = RPCRouter()
2222

2323

2424
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
2525
async def start_task(
26-
long_running_manager: "BaseLongRunningManager",
26+
long_running_manager: "LongRunningManager",
2727
*,
2828
registered_task_name: RegisteredTaskName,
2929
unique: bool = False,
@@ -44,7 +44,7 @@ async def start_task(
4444

4545
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
4646
async def list_tasks(
47-
long_running_manager: "BaseLongRunningManager", *, task_context: TaskContext
47+
long_running_manager: "LongRunningManager", *, task_context: TaskContext
4848
) -> list[TaskBase]:
4949
return await long_running_manager.tasks_manager.list_tasks(
5050
with_task_context=task_context
@@ -53,7 +53,7 @@ async def list_tasks(
5353

5454
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
5555
async def get_task_status(
56-
long_running_manager: "BaseLongRunningManager",
56+
long_running_manager: "LongRunningManager",
5757
*,
5858
task_context: TaskContext,
5959
task_id: TaskId,
@@ -65,7 +65,7 @@ async def get_task_status(
6565

6666
@router.expose(reraise_if_error_type=(BaseLongRunningError, RPCTransferrableTaskError))
6767
async def get_task_result(
68-
long_running_manager: "BaseLongRunningManager",
68+
long_running_manager: "LongRunningManager",
6969
*,
7070
task_context: TaskContext,
7171
task_id: TaskId,
@@ -94,7 +94,7 @@ async def get_task_result(
9494

9595
@router.expose(reraise_if_error_type=(BaseLongRunningError,))
9696
async def remove_task(
97-
long_running_manager: "BaseLongRunningManager",
97+
long_running_manager: "LongRunningManager",
9898
*,
9999
task_context: TaskContext,
100100
task_id: TaskId,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import logging
2+
3+
import redis.asyncio as aioredis
4+
from settings_library.redis import RedisDatabase, RedisSettings
5+
6+
from ..logging_utils import log_context
7+
from ..redis._client import RedisClientSDK
8+
from .models import LRTNamespace
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
class LongRunningClientHelper:
14+
def __init__(self, redis_settings: RedisSettings):
15+
self.redis_settings = redis_settings
16+
17+
self._client: RedisClientSDK | None = None
18+
19+
async def setup(self) -> None:
20+
self._client = RedisClientSDK(
21+
self.redis_settings.build_redis_dsn(RedisDatabase.LONG_RUNNING_TASKS),
22+
client_name="long_running_tasks_cleanup_client",
23+
)
24+
await self._client.setup()
25+
26+
async def shutdown(self) -> None:
27+
if self._client:
28+
await self._client.shutdown()
29+
30+
@property
31+
def _redis(self) -> aioredis.Redis:
32+
assert self._client # nosec
33+
return self._client.redis
34+
35+
async def cleanup(self, lrt_namespace: LRTNamespace) -> None:
36+
"""removes Redis keys associated to the LRTNamespace if they exist"""
37+
keys_to_remove: list[str] = [
38+
x async for x in self._redis.scan_iter(f"{lrt_namespace}*")
39+
]
40+
with log_context(
41+
_logger,
42+
logging.DEBUG,
43+
msg=f"Removing {keys_to_remove=} from Redis for {lrt_namespace=}",
44+
):
45+
if len(keys_to_remove) > 0:
46+
await self._redis.delete(*keys_to_remove)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from .task import TasksManager
1212

1313

14-
class BaseLongRunningManager(ABC):
14+
class LongRunningManager(ABC):
1515
"""
1616
Provides a commond inteface for aiohttp and fastapi services
1717
"""

packages/service-library/src/servicelib/long_running_tasks/task.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from common_library.async_tools import cancel_wait_task
1212
from models_library.api_schemas_long_running_tasks.base import TaskProgress
1313
from pydantic import NonNegativeFloat, PositiveFloat
14+
from servicelib.utils import limited_gather
1415
from settings_library.redis import RedisDatabase, RedisSettings
1516
from tenacity import (
1617
AsyncRetrying,
@@ -21,7 +22,7 @@
2122

2223
from ..background_task import create_periodic_task
2324
from ..logging_errors import create_troubleshootting_log_kwargs
24-
from ..logging_utils import log_context
25+
from ..logging_utils import log_catch, log_context
2526
from ..redis import RedisClientSDK, exclusive
2627
from ._redis_store import RedisStore
2728
from ._serialization import dumps
@@ -50,6 +51,7 @@
5051
_STATUS_UPDATE_CHECK_INTERNAL: Final[datetime.timedelta] = datetime.timedelta(seconds=1)
5152
_MAX_EXCLUSIVE_TASK_CANCEL_TIMEOUT: Final[NonNegativeFloat] = 5
5253
_TASK_REMOVAL_MAX_WAIT: Final[NonNegativeFloat] = 60
54+
_PARALLEL_TASKS_CANCELLATION: Final[int] = 5
5355

5456
AllowedErrrors: TypeAlias = tuple[type[BaseException], ...]
5557

@@ -205,34 +207,40 @@ async def setup(self) -> None:
205207
await self._started_event_task_tasks_monitor.wait()
206208

207209
async def teardown(self) -> None:
208-
# ensure all created tasks are cancelled
209-
for tracked_task in await self._tasks_data.list_tasks_data():
210-
with suppress(TaskNotFoundError):
210+
# stop cancelled_tasks_removal
211+
if self._task_cancelled_tasks_removal:
212+
await cancel_wait_task(self._task_cancelled_tasks_removal)
213+
214+
# stopping only tasks that are handled by this manager
215+
# otherwise it will cancel long running tasks that were running in diffierent processes
216+
async def _remove_local_task(task_data: TaskData) -> None:
217+
with log_catch(_logger, reraise=False):
211218
await self.remove_task(
212-
tracked_task.task_id,
213-
tracked_task.task_context,
214-
wait_for_removal=True,
219+
task_data.task_id,
220+
task_data.task_context,
221+
wait_for_removal=False,
215222
)
223+
await self._attempt_to_remove_local_task(task_data.task_id)
224+
225+
await limited_gather(
226+
*[
227+
_remove_local_task(tracked_task)
228+
for task_id in self._created_tasks
229+
if (tracked_task := await self._tasks_data.get_task_data(task_id))
230+
is not None
231+
],
232+
log=_logger,
233+
limit=_PARALLEL_TASKS_CANCELLATION,
234+
)
216235

217-
for task in self._created_tasks.values():
218-
_logger.warning(
219-
"Task %s was not completed before shutdown, cancelling it",
220-
task.get_name(),
221-
)
222-
await cancel_wait_task(task)
223-
224-
# stale_tasks_monitor
236+
# stop stale_tasks_monitor
225237
if self._task_stale_tasks_monitor:
226238
await cancel_wait_task(
227239
self._task_stale_tasks_monitor,
228240
max_delay=_MAX_EXCLUSIVE_TASK_CANCEL_TIMEOUT,
229241
)
230242

231-
# cancelled_tasks_removal
232-
if self._task_cancelled_tasks_removal:
233-
await cancel_wait_task(self._task_cancelled_tasks_removal)
234-
235-
# tasks_monitor
243+
# stop tasks_monitor
236244
if self._task_tasks_monitor:
237245
await cancel_wait_task(self._task_tasks_monitor)
238246

packages/service-library/tests/long_running_tasks/conftest.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from faker import Faker
1212
from pytest_mock import MockerFixture
1313
from servicelib.logging_utils import log_catch
14-
from servicelib.long_running_tasks.base_long_running_manager import (
15-
BaseLongRunningManager,
14+
from servicelib.long_running_tasks.manager import (
15+
LongRunningManager,
1616
)
1717
from servicelib.long_running_tasks.models import LRTNamespace, TaskContext
1818
from servicelib.long_running_tasks.task import TasksManager
@@ -24,7 +24,7 @@
2424
_logger = logging.getLogger(__name__)
2525

2626

27-
class _TestingLongRunningManager(BaseLongRunningManager):
27+
class _TestingLongRunningManager(LongRunningManager):
2828
@staticmethod
2929
def get_task_context(request) -> TaskContext:
3030
_ = request
@@ -37,16 +37,16 @@ async def get_long_running_manager(
3737
) -> AsyncIterator[
3838
Callable[
3939
[RedisSettings, RabbitSettings, LRTNamespace | None],
40-
Awaitable[BaseLongRunningManager],
40+
Awaitable[LongRunningManager],
4141
]
4242
]:
43-
managers: list[BaseLongRunningManager] = []
43+
managers: list[LongRunningManager] = []
4444

4545
async def _(
4646
redis_settings: RedisSettings,
4747
rabbit_settings: RabbitSettings,
4848
lrt_namespace: LRTNamespace | None,
49-
) -> BaseLongRunningManager:
49+
) -> LongRunningManager:
5050
manager = _TestingLongRunningManager(
5151
stale_task_check_interval=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),
5252
stale_task_detect_timeout=timedelta(seconds=TEST_CHECK_STALE_INTERVAL_S),

0 commit comments

Comments
 (0)