Skip to content

Commit 381dc67

Browse files
πŸŽ¨πŸ› Fix filtering bug in celery tasks (#8355)
1 parent ae14233 commit 381dc67

File tree

31 files changed

+361
-185
lines changed

31 files changed

+361
-185
lines changed

packages/celery-library/src/celery_library/backends/redis.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
import contextlib
22
import logging
33
from datetime import timedelta
4-
from typing import Final
4+
from typing import TYPE_CHECKING, Final
55

66
from models_library.progress_bar import ProgressReport
77
from pydantic import ValidationError
88
from servicelib.celery.models import (
99
Task,
1010
TaskFilter,
1111
TaskID,
12+
TaskInfoStore,
1213
TaskMetadata,
13-
TaskUUID,
14+
Wildcard,
1415
)
1516
from servicelib.redis import RedisClientSDK, handle_redis_returns_union_types
1617

17-
from ..utils import build_task_id_prefix
18-
1918
_CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
2019
_CELERY_TASK_ID_KEY_ENCODING = "utf-8"
21-
_CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
2220
_CELERY_TASK_SCAN_COUNT_PER_BATCH: Final[int] = 1000
2321
_CELERY_TASK_METADATA_KEY: Final[str] = "metadata"
2422
_CELERY_TASK_PROGRESS_KEY: Final[str] = "progress"
@@ -88,17 +86,14 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
8886
return None
8987

9088
async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
91-
search_key = (
92-
_CELERY_TASK_INFO_PREFIX
93-
+ build_task_id_prefix(task_filter)
94-
+ _CELERY_TASK_ID_KEY_SEPARATOR
89+
search_key = _CELERY_TASK_INFO_PREFIX + task_filter.create_task_id(
90+
task_uuid=Wildcard()
9591
)
96-
search_key_len = len(search_key)
9792

9893
keys: list[str] = []
9994
pipeline = self._redis_client_sdk.redis.pipeline()
10095
async for key in self._redis_client_sdk.redis.scan_iter(
101-
match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
96+
match=search_key, count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
10297
):
10398
# fake redis (tests) returns bytes, real redis returns str
10499
_key = (
@@ -120,7 +115,7 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
120115
task_metadata = TaskMetadata.model_validate_json(raw_metadata)
121116
tasks.append(
122117
Task(
123-
uuid=TaskUUID(key[search_key_len:]),
118+
uuid=TaskFilter.get_task_uuid(key),
124119
metadata=task_metadata,
125120
)
126121
)
@@ -143,3 +138,7 @@ async def task_exists(self, task_id: TaskID) -> bool:
143138
n = await self._redis_client_sdk.redis.exists(_build_key(task_id))
144139
assert isinstance(n, int) # nosec
145140
return n > 0
141+
142+
143+
if TYPE_CHECKING:
144+
_: type[TaskInfoStore] = RedisTaskInfoStore

packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,8 @@ async def result(
134134

135135
@router.expose(reraise_if_error_type=(JobSchedulerError,))
136136
async def list_jobs(
137-
task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter
137+
task_manager: TaskManager, job_filter: AsyncJobFilter
138138
) -> list[AsyncJobGet]:
139-
_ = filter_
140139
assert task_manager # nosec
141140
task_filter = TaskFilter.model_validate(job_filter.model_dump())
142141
try:

packages/celery-library/src/celery_library/task_manager.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from settings_library.celery import CelerySettings
2323

2424
from .errors import TaskNotFoundError
25-
from .utils import build_task_id
2625

2726
_logger = logging.getLogger(__name__)
2827

@@ -50,7 +49,7 @@ async def submit_task(
5049
msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}",
5150
):
5251
task_uuid = uuid4()
53-
task_id = build_task_id(task_filter, task_uuid)
52+
task_id = task_filter.create_task_id(task_uuid=task_uuid)
5453
self._celery_app.send_task(
5554
task_metadata.name,
5655
task_id=task_id,
@@ -74,7 +73,7 @@ async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> Non
7473
logging.DEBUG,
7574
msg=f"task cancellation: {task_filter=} {task_uuid=}",
7675
):
77-
task_id = build_task_id(task_filter, task_uuid)
76+
task_id = task_filter.create_task_id(task_uuid=task_uuid)
7877
if not await self.task_exists(task_id):
7978
raise TaskNotFoundError(task_id=task_id)
8079

@@ -96,7 +95,7 @@ async def get_task_result(
9695
logging.DEBUG,
9796
msg=f"Get task result: {task_filter=} {task_uuid=}",
9897
):
99-
task_id = build_task_id(task_filter, task_uuid)
98+
task_id = task_filter.create_task_id(task_uuid=task_uuid)
10099
if not await self.task_exists(task_id):
101100
raise TaskNotFoundError(task_id=task_id)
102101

@@ -139,7 +138,7 @@ async def get_task_status(
139138
logging.DEBUG,
140139
msg=f"Getting task status: {task_filter=} {task_uuid=}",
141140
):
142-
task_id = build_task_id(task_filter, task_uuid)
141+
task_id = task_filter.create_task_id(task_uuid=task_uuid)
143142
if not await self.task_exists(task_id):
144143
raise TaskNotFoundError(task_id=task_id)
145144

packages/celery-library/src/celery_library/utils.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,8 @@
1-
from typing import Final
2-
31
from celery import Celery # type: ignore[import-untyped]
42
from servicelib.celery.app_server import BaseAppServer
5-
from servicelib.celery.models import TaskFilter, TaskID, TaskUUID
63

74
_APP_SERVER_KEY = "app_server"
85

9-
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
10-
11-
12-
def build_task_id_prefix(task_filter: TaskFilter) -> str:
13-
filter_dict = task_filter.model_dump()
14-
return _TASK_ID_KEY_DELIMITATOR.join(
15-
[f"{filter_dict[key]}" for key in sorted(filter_dict)]
16-
)
17-
18-
19-
def build_task_id(task_filter: TaskFilter, task_uuid: TaskUUID) -> TaskID:
20-
return _TASK_ID_KEY_DELIMITATOR.join(
21-
[build_task_id_prefix(task_filter), f"{task_uuid}"]
22-
)
23-
246

257
def get_app_server(app: Celery) -> BaseAppServer:
268
app_server = app.conf[_APP_SERVER_KEY]

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,6 @@ async def test_async_jobs_workflow(
258258
jobs = await async_jobs.list_jobs(
259259
async_jobs_rabbitmq_rpc_client,
260260
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
261-
filter_="", # currently not used
262261
job_filter=job_filter,
263262
)
264263
assert len(jobs) > 0
@@ -311,7 +310,6 @@ async def test_async_jobs_cancel(
311310
jobs = await async_jobs.list_jobs(
312311
async_jobs_rabbitmq_rpc_client,
313312
rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
314-
filter_="", # currently not used
315313
job_filter=job_filter,
316314
)
317315
assert async_job_get.job_id not in [job.job_id for job in jobs]

packages/celery-library/tests/unit/test_tasks.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,21 @@
1818
from celery_library.task_manager import CeleryTaskManager
1919
from celery_library.utils import get_app_server
2020
from common_library.errors_classes import OsparcErrorMixin
21+
from faker import Faker
2122
from models_library.progress_bar import ProgressReport
2223
from servicelib.celery.models import (
2324
TaskFilter,
2425
TaskID,
2526
TaskMetadata,
2627
TaskState,
28+
TaskUUID,
29+
Wildcard,
2730
)
2831
from servicelib.logging_utils import log_context
2932
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
3033

34+
_faker = Faker()
35+
3136
_logger = logging.getLogger(__name__)
3237

3338
pytest_simcore_core_services_selection = ["redis"]
@@ -199,5 +204,53 @@ async def test_listing_task_uuids_contains_submitted_task(
199204
tasks = await celery_task_manager.list_tasks(task_filter)
200205
assert any(task.uuid == task_uuid for task in tasks)
201206

202-
tasks = await celery_task_manager.list_tasks(task_filter)
203-
assert any(task.uuid == task_uuid for task in tasks)
207+
tasks = await celery_task_manager.list_tasks(task_filter)
208+
assert any(task.uuid == task_uuid for task in tasks)
209+
210+
211+
async def test_filtering_listing_tasks(
212+
celery_task_manager: CeleryTaskManager,
213+
):
214+
class MyFilter(TaskFilter):
215+
user_id: int
216+
product_name: str | Wildcard
217+
client_app: str | Wildcard
218+
219+
user_id = 42
220+
expected_task_uuids: set[TaskUUID] = set()
221+
222+
for _ in range(5):
223+
task_filter = MyFilter(
224+
user_id=user_id,
225+
product_name=_faker.word(),
226+
client_app=_faker.word(),
227+
)
228+
task_uuid = await celery_task_manager.submit_task(
229+
TaskMetadata(
230+
name=dreamer_task.__name__,
231+
),
232+
task_filter=task_filter,
233+
)
234+
expected_task_uuids.add(task_uuid)
235+
236+
for _ in range(3):
237+
task_filter = MyFilter(
238+
user_id=_faker.pyint(min_value=100, max_value=200),
239+
product_name=_faker.word(),
240+
client_app=_faker.word(),
241+
)
242+
await celery_task_manager.submit_task(
243+
TaskMetadata(
244+
name=dreamer_task.__name__,
245+
),
246+
task_filter=task_filter,
247+
)
248+
249+
search_filter = MyFilter(
250+
user_id=user_id,
251+
product_name=Wildcard(),
252+
client_app=Wildcard(),
253+
)
254+
tasks = await celery_task_manager.list_tasks(search_filter)
255+
assert expected_task_uuids == {task.uuid for task in tasks}
256+
await asyncio.sleep(5 * 60)

packages/models-library/src/models_library/api_schemas_rpc_async_jobs/async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ class AsyncJobFilter(AsyncJobFilterBase):
6767

6868
product_name: ProductName
6969
user_id: UserID
70-
client_name: Annotated[
70+
client_name: Annotated[ # this is the name of the app which *submits* the async job. It is mainly used for filtering purposes
7171
str,
7272
StringConstraints(min_length=1, pattern=r"^[^\s]+$"),
7373
]

packages/pytest-simcore/src/pytest_simcore/helpers/storage_rpc_server.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@
1313
AsyncJobGet,
1414
)
1515
from models_library.api_schemas_webserver.storage import PathToExport
16-
from models_library.products import ProductName
17-
from models_library.users import UserID
1816
from pydantic import TypeAdapter, validate_call
1917
from pytest_mock import MockType
2018
from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
@@ -27,14 +25,12 @@ async def start_export_data(
2725
self,
2826
rabbitmq_rpc_client: RabbitMQRPCClient | MockType,
2927
*,
30-
user_id: UserID,
31-
product_name: ProductName,
3228
paths_to_export: list[PathToExport],
3329
export_as: Literal["path", "download_link"],
30+
job_filter: AsyncJobFilter,
3431
) -> tuple[AsyncJobGet, AsyncJobFilter]:
3532
assert rabbitmq_rpc_client
36-
assert user_id
37-
assert product_name
33+
assert job_filter
3834
assert paths_to_export
3935
assert export_as
4036

packages/service-library/src/servicelib/celery/models.py

Lines changed: 86 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,103 @@
11
import datetime
22
from enum import StrEnum
3-
from typing import Annotated, Final, Protocol, TypeAlias
3+
from typing import Annotated, Any, Final, Protocol, Self, TypeAlias, TypeVar
44
from uuid import UUID
55

66
from models_library.progress_bar import ProgressReport
7-
from pydantic import BaseModel, ConfigDict, StringConstraints
7+
from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator
88
from pydantic.config import JsonDict
99

10+
ModelType = TypeVar("ModelType", bound=BaseModel)
11+
1012
TaskID: TypeAlias = str
1113
TaskName: TypeAlias = Annotated[
1214
str, StringConstraints(strip_whitespace=True, min_length=1)
1315
]
1416
TaskUUID: TypeAlias = UUID
17+
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
18+
_WILDCARD: Final[str] = "*"
19+
_FORBIDDEN_CHARS = (_WILDCARD, _TASK_ID_KEY_DELIMITATOR, "=")
20+
21+
22+
class Wildcard:
23+
def __str__(self) -> str:
24+
return _WILDCARD
25+
26+
27+
class TaskFilter(BaseModel):
28+
"""
29+
Class for associating metadata with a celery task. The implementation is very flexible and allows "clients" to define their own metadata.
30+
The class exposes a filtering mechanism to list tasks using wildcards.
31+
32+
Example usage:
33+
class MyTaskFilter(TaskFilter):
34+
user_id: int | Wildcard
35+
product_name: int | Wildcard
36+
client_name: str
37+
38+
Listing tasks using the filter `MyTaskFilter(user_id=123, product_name=Wildcard(), client_name="my-app")` will return all tasks with
39+
user_id 123, any product_name submitted from my-app.
40+
41+
If the metadata schema is known, the class allows deserializing the metadata (recreate_as_model). I.e. one can recover the metadata from the task:
42+
metadata -> task_uuid -> metadata
43+
44+
"""
45+
46+
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
47+
48+
@model_validator(mode="after")
49+
def _check_valid_filters(self) -> Self:
50+
for key, value in self.model_dump().items():
51+
# forbidden keys
52+
if any(x in key for x in _FORBIDDEN_CHARS):
53+
raise ValueError(f"Invalid filter key: '{key}'")
54+
# forbidden values
55+
if not isinstance(value, Wildcard) and any(
56+
x in f"{value}" for x in _FORBIDDEN_CHARS
57+
):
58+
raise ValueError(f"Invalid filter value for key '{key}': '{value}'")
59+
return self
60+
61+
def _build_task_id_prefix(self) -> str:
62+
filter_dict = self.model_dump()
63+
return _TASK_ID_KEY_DELIMITATOR.join(
64+
[f"{key}={filter_dict[key]}" for key in sorted(filter_dict)]
65+
)
1566

67+
def create_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID:
68+
return _TASK_ID_KEY_DELIMITATOR.join(
69+
[
70+
self._build_task_id_prefix(),
71+
f"task_uuid={task_uuid}",
72+
]
73+
)
1674

17-
class TaskFilter(BaseModel): ...
75+
@classmethod
76+
def recreate_as_model(cls, task_id: TaskID, schema: type[ModelType]) -> ModelType:
77+
filter_dict = cls._recreate_data(task_id)
78+
return schema.model_validate(filter_dict)
79+
80+
@classmethod
81+
def _recreate_data(cls, task_id: TaskID) -> dict[str, Any]:
82+
"""Recreates the filter data from a task_id string
83+
WARNING: does not validate types. For that use `recreate_model` instead
84+
"""
85+
try:
86+
parts = task_id.split(_TASK_ID_KEY_DELIMITATOR)
87+
return {
88+
key: value
89+
for part in parts[:-1]
90+
if (key := part.split("=")[0]) and (value := part.split("=")[1])
91+
}
92+
except (IndexError, ValueError) as err:
93+
raise ValueError(f"Invalid task_id format: {task_id}") from err
94+
95+
@classmethod
96+
def get_task_uuid(cls, task_id: TaskID) -> TaskUUID:
97+
try:
98+
return UUID(task_id.split(_TASK_ID_KEY_DELIMITATOR)[-1].split("=")[1])
99+
except (IndexError, ValueError) as err:
100+
raise ValueError(f"Invalid task_id format: {task_id}") from err
18101

19102

20103
class TaskState(StrEnum):

packages/service-library/src/servicelib/rabbitmq/rpc_interfaces/async_jobs/async_jobs.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,11 @@ async def list_jobs(
9292
rabbitmq_rpc_client: RabbitMQRPCClient,
9393
*,
9494
rpc_namespace: RPCNamespace,
95-
filter_: str,
9695
job_filter: AsyncJobFilter,
9796
) -> list[AsyncJobGet]:
9897
_result: list[AsyncJobGet] = await rabbitmq_rpc_client.request(
9998
rpc_namespace,
10099
TypeAdapter(RPCMethodName).validate_python("list_jobs"),
101-
filter_=filter_,
102100
job_filter=job_filter,
103101
timeout_s=_DEFAULT_TIMEOUT_S,
104102
)

0 commit comments

Comments
(0)