Merged

Changes from 23 commits (40 commits total)
35402ee  allow extra fields in TaskFilter and test  (bisgaard-itis, Sep 10, 2025)
bbaf2a5  add test of field_sorting_key  (bisgaard-itis, Sep 10, 2025)
3aeb4fb  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 11, 2025)
093b6c5  extend functionality of TaskFilter  (bisgaard-itis, Sep 11, 2025)
bde276c  fix listing of tasks  (bisgaard-itis, Sep 11, 2025)
6daf856  add test for filtering  (bisgaard-itis, Sep 11, 2025)
2d01d69  upper -> lower case client ASYNC_JOB_CLIENT_NAMEs  (bisgaard-itis, Sep 11, 2025)
494a8a2  remove filtering input to rpc async job interface  (bisgaard-itis, Sep 11, 2025)
91b7fbc  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 11, 2025)
59b0064  minor cleanup  (bisgaard-itis, Sep 11, 2025)
f9f04d6  pylint  (bisgaard-itis, Sep 11, 2025)
e9f5d42  pylint  (bisgaard-itis, Sep 11, 2025)
60a72aa  fix test  (bisgaard-itis, Sep 11, 2025)
59fe947  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 11, 2025)
4f4b1f5  correct wrong usage of client_app_name  (bisgaard-itis, Sep 11, 2025)
31fc291  add doc  (bisgaard-itis, Sep 11, 2025)
6389fda  use app-name for ASYNC_JOB_CLIENT_NAME in api-server  (bisgaard-itis, Sep 11, 2025)
33a2ae4  use app-name in storage when specifying client_name in AsyncJobFilter  (bisgaard-itis, Sep 11, 2025)
be84da1  use app-name in AsyncJobFilter in webserver  (bisgaard-itis, Sep 11, 2025)
9a8d113  cleanup in webserver  (bisgaard-itis, Sep 11, 2025)
9fe331b  cleanup in api-server  (bisgaard-itis, Sep 11, 2025)
fdff2d8  fix import error  (bisgaard-itis, Sep 11, 2025)
b04a0af  fix storage mocks  (bisgaard-itis, Sep 11, 2025)
87dedd9  fix mock  (bisgaard-itis, Sep 12, 2025)
7ede81b  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 12, 2025)
36d92ff  Several minor corrections @pcrespov  (bisgaard-itis, Sep 12, 2025)
c767367  document usage @pcrespov  (bisgaard-itis, Sep 12, 2025)
7ede2aa  introduce backend agnostic wildcard  (bisgaard-itis, Sep 12, 2025)
e5b2602  simplify wildcard usage  (bisgaard-itis, Sep 12, 2025)
a906622  minor cleanup  (bisgaard-itis, Sep 12, 2025)
6a9b046  enhance docs  (bisgaard-itis, Sep 12, 2025)
f81dccd  cleanup @pcrespov  (bisgaard-itis, Sep 12, 2025)
5603624  pylint  (bisgaard-itis, Sep 12, 2025)
b8359d9  simplify wildcard usage  (bisgaard-itis, Sep 12, 2025)
52897f5  minor fix  (bisgaard-itis, Sep 12, 2025)
7c4fbba  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 12, 2025)
80817ec  model -> schema @pcrespov  (bisgaard-itis, Sep 12, 2025)
ae8d703  fixes @pcrespov  (bisgaard-itis, Sep 12, 2025)
c5bc5d2  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 12, 2025)
47ab209  Merge branch 'master' into 8333-fix-filtering-bug-in-celery-tasks  (bisgaard-itis, Sep 12, 2025)
21 changes: 4 additions & 17 deletions packages/celery-library/src/celery_library/backends/_redis.py
@@ -5,17 +5,9 @@

 from models_library.progress_bar import ProgressReport
 from pydantic import ValidationError
-from servicelib.celery.models import (
-    Task,
-    TaskFilter,
-    TaskID,
-    TaskMetadata,
-    TaskUUID,
-)
+from servicelib.celery.models import Task, TaskFilter, TaskID, TaskMetadata
 from servicelib.redis import RedisClientSDK

-from ..utils import build_task_id_prefix
-
 _CELERY_TASK_INFO_PREFIX: Final[str] = "celery-task-info-"
 _CELERY_TASK_ID_KEY_ENCODING = "utf-8"
 _CELERY_TASK_ID_KEY_SEPARATOR: Final[str] = ":"
@@ -83,17 +75,12 @@ async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
         return None

     async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
-        search_key = (
-            _CELERY_TASK_INFO_PREFIX
-            + build_task_id_prefix(task_filter)
-            + _CELERY_TASK_ID_KEY_SEPARATOR
-        )
-        search_key_len = len(search_key)
+        search_key = _CELERY_TASK_INFO_PREFIX + task_filter.task_id("*")

         keys: list[str] = []
         pipeline = self._redis_client_sdk.redis.pipeline()
         async for key in self._redis_client_sdk.redis.scan_iter(
-            match=search_key + "*", count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
+            match=search_key, count=_CELERY_TASK_SCAN_COUNT_PER_BATCH
         ):
             # fake redis (tests) returns bytes, real redis returns str
             _key = (
@@ -115,7 +102,7 @@ async def list_tasks(self, task_filter: TaskFilter) -> list[Task]:
             task_metadata = TaskMetadata.model_validate_json(raw_metadata)
             tasks.append(
                 Task(
-                    uuid=TaskUUID(key[search_key_len:]),
+                    uuid=TaskFilter.task_uuid(key),
                     metadata=task_metadata,
                 )
             )
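For orientation, here is a minimal sketch of the key scheme this scan now relies on. The filter fields and values below are invented for illustration only; TaskFilter accepts arbitrary extra fields:

    # hypothetical filter data, for illustration only
    task_filter = TaskFilter.model_validate({"product_name": "osparc", "user_id": 42})
    task_filter.task_id("*")
    # -> "product_name=osparc:user_id=42:task_uuid=*"
    search_key = "celery-task-info-" + task_filter.task_id("*")
    # SCAN pattern matches every task submitted with exactly these filter values
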
@@ -127,9 +127,8 @@ async def result(

 @router.expose(reraise_if_error_type=(JobSchedulerError,))
 async def list_jobs(
-    task_manager: TaskManager, filter_: str, job_filter: AsyncJobFilter
+    task_manager: TaskManager, job_filter: AsyncJobFilter
 ) -> list[AsyncJobGet]:
-    _ = filter_
     assert task_manager  # nosec
     task_filter = TaskFilter.model_validate(job_filter.model_dump())
     try:
12 changes: 5 additions & 7 deletions packages/celery-library/src/celery_library/task_manager.py
@@ -23,8 +23,6 @@

 from servicelib.logging_utils import log_context
 from settings_library.celery import CelerySettings

-from .utils import build_task_id
-
 _logger = logging.getLogger(__name__)

@@ -51,7 +49,7 @@ async def submit_task(
             msg=f"Submit {task_metadata.name=}: {task_filter=} {task_params=}",
         ):
             task_uuid = uuid4()
-            task_id = build_task_id(task_filter, task_uuid)
+            task_id = task_filter.task_id(task_uuid=task_uuid)
             self._celery_app.send_task(
                 task_metadata.name,
                 task_id=task_id,
@@ -79,7 +77,7 @@ async def cancel_task(self, task_filter: TaskFilter, task_uuid: TaskUUID) -> None:
             logging.DEBUG,
             msg=f"task cancellation: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_filter, task_uuid)
+            task_id = task_filter.task_id(task_uuid=task_uuid)
             if not (await self.get_task_status(task_filter, task_uuid)).is_done:
                 await self._abort_task(task_id)
             await self._task_info_store.remove_task(task_id)
@@ -96,7 +94,7 @@ async def get_task_result(
             logging.DEBUG,
             msg=f"Get task result: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_filter, task_uuid)
+            task_id = task_filter.task_id(task_uuid=task_uuid)
             async_result = self._celery_app.AsyncResult(task_id)
             result = async_result.result
             if async_result.ready():
@@ -110,7 +108,7 @@ async def _get_task_progress_report(
         self, task_filter: TaskFilter, task_uuid: TaskUUID, task_state: TaskState
     ) -> ProgressReport:
         if task_state in (TaskState.STARTED, TaskState.RETRY, TaskState.ABORTED):
-            task_id = build_task_id(task_filter, task_uuid)
+            task_id = task_filter.task_id(task_uuid=task_uuid)
             progress = await self._task_info_store.get_task_progress(task_id)
             if progress is not None:
                 return progress
@@ -139,7 +137,7 @@ async def get_task_status(
             logging.DEBUG,
             msg=f"Getting task status: {task_filter=} {task_uuid=}",
         ):
-            task_id = build_task_id(task_filter, task_uuid)
+            task_id = task_filter.task_id(task_uuid=task_uuid)
             task_state = await self._get_task_celery_state(task_id)
             return TaskStatus(
                 task_uuid=task_uuid,
18 changes: 0 additions & 18 deletions packages/celery-library/src/celery_library/utils.py
@@ -1,26 +1,8 @@
-from typing import Final
-
 from celery import Celery  # type: ignore[import-untyped]
 from servicelib.celery.app_server import BaseAppServer
-from servicelib.celery.models import TaskFilter, TaskID, TaskUUID

 _APP_SERVER_KEY = "app_server"

-_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
-
-
-def build_task_id_prefix(task_filter: TaskFilter) -> str:
-    filter_dict = task_filter.model_dump()
-    return _TASK_ID_KEY_DELIMITATOR.join(
-        [f"{filter_dict[key]}" for key in sorted(filter_dict)]
-    )
-
-
-def build_task_id(task_filter: TaskFilter, task_uuid: TaskUUID) -> TaskID:
-    return _TASK_ID_KEY_DELIMITATOR.join(
-        [build_task_id_prefix(task_filter), f"{task_uuid}"]
-    )
-
-
 def get_app_server(app: Celery) -> BaseAppServer:
     app_server = app.conf[_APP_SERVER_KEY]
2 changes: 0 additions & 2 deletions packages/celery-library/tests/unit/test_async_jobs.py
@@ -258,7 +258,6 @@ async def test_async_jobs_workflow(
     jobs = await async_jobs.list_jobs(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
-        filter_="",  # currently not used
         job_filter=job_filter,
     )
     assert len(jobs) > 0
@@ -317,7 +316,6 @@ async def test_async_jobs_cancel(
     jobs = await async_jobs.list_jobs(
         async_jobs_rabbitmq_rpc_client,
         rpc_namespace=ASYNC_JOBS_RPC_NAMESPACE,
-        filter_="",  # currently not used
         job_filter=job_filter,
     )
     assert async_job_get.job_id not in [job.job_id for job in jobs]
57 changes: 57 additions & 0 deletions packages/celery-library/tests/unit/test_tasks.py
@@ -18,16 +18,22 @@
 from celery_library.task_manager import CeleryTaskManager
 from celery_library.utils import get_app_server
 from common_library.errors_classes import OsparcErrorMixin
+from faker import Faker
 from models_library.progress_bar import ProgressReport
 from pydantic import BaseModel
 from servicelib.celery.models import (
+    WILDCARD,
     TaskFilter,
     TaskID,
     TaskMetadata,
     TaskState,
+    TaskUUID,
 )
 from servicelib.logging_utils import log_context
 from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed

+_faker = Faker()
+
 _logger = logging.getLogger(__name__)

 pytest_simcore_core_services_selection = ["redis"]
@@ -203,3 +209,54 @@ async def test_listing_task_uuids_contains_submitted_task(

     tasks = await celery_task_manager.list_tasks(task_filter)
     assert any(task.uuid == task_uuid for task in tasks)
+
+
+async def test_filtering_listing_tasks(
+    celery_task_manager: CeleryTaskManager,
+):
+    class MyFilter(BaseModel):
+        user_id: int
+        product_name: str
+        client_app: str
+
+    user_id = 42
+    expected_task_uuids: set[TaskUUID] = set()
+
+    for _ in range(5):
+        myfilter = MyFilter(
+            user_id=user_id,
+            product_name=_faker.word(),
+            client_app=_faker.word(),
+        )
+        task_filter = TaskFilter.model_validate(myfilter.model_dump())
+        task_uuid = await celery_task_manager.submit_task(
+            TaskMetadata(
+                name=dreamer_task.__name__,
+            ),
+            task_filter=task_filter,
+        )
+        expected_task_uuids.add(task_uuid)
+
+    for _ in range(3):
+        myfilter = MyFilter(
+            user_id=_faker.pyint(min_value=100, max_value=200),
+            product_name=_faker.word(),
+            client_app=_faker.word(),
+        )
+        task_filter = TaskFilter.model_validate(myfilter.model_dump())
+        await celery_task_manager.submit_task(
+            TaskMetadata(
+                name=dreamer_task.__name__,
+            ),
+            task_filter=task_filter,
+        )
+
+    search_filter = MyFilter(
+        user_id=user_id,
+        product_name=WILDCARD,
+        client_app=WILDCARD,
+    )
+    tasks = await celery_task_manager.list_tasks(
+        TaskFilter.model_validate(search_filter.model_dump())
+    )
+    assert expected_task_uuids == {task.uuid for task in tasks}
@@ -67,7 +67,7 @@ class AsyncJobFilter(AsyncJobFilterBase):

     product_name: ProductName
     user_id: UserID
-    client_name: Annotated[
+    client_name: Annotated[  # this is the name of the app which *submits* the async job. It is mainly used for filtering purposes
         str,
         StringConstraints(min_length=1, pattern=r"^[^\s]+$"),
     ]
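Since TaskFilter allows extra fields, an AsyncJobFilter dump validates directly against it, which is the pattern the list_jobs handler above relies on. A minimal sketch, with invented values:

    # hypothetical values, for illustration only
    job_filter = AsyncJobFilter(
        product_name="osparc",
        user_id=1,
        client_name="api-server",
    )
    # TaskFilter has extra="allow", so the dumped fields carry over unchanged
    task_filter = TaskFilter.model_validate(job_filter.model_dump())
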
@@ -13,8 +13,6 @@
     AsyncJobGet,
 )
 from models_library.api_schemas_webserver.storage import PathToExport
-from models_library.products import ProductName
-from models_library.users import UserID
 from pydantic import TypeAdapter, validate_call
 from pytest_mock import MockType
 from servicelib.rabbitmq._client_rpc import RabbitMQRPCClient
@@ -27,14 +25,12 @@ async def start_export_data(
         self,
         rabbitmq_rpc_client: RabbitMQRPCClient | MockType,
         *,
-        user_id: UserID,
-        product_name: ProductName,
         paths_to_export: list[PathToExport],
         export_as: Literal["path", "download_link"],
         job_filter: AsyncJobFilter,
     ) -> tuple[AsyncJobGet, AsyncJobFilter]:
         assert rabbitmq_rpc_client
-        assert user_id
-        assert product_name
         assert job_filter
         assert paths_to_export
         assert export_as
63 changes: 60 additions & 3 deletions packages/service-library/src/servicelib/celery/models.py
@@ -1,20 +1,77 @@
 import datetime
 from enum import StrEnum
-from typing import Annotated, Protocol, TypeAlias
+from typing import Annotated, Any, Final, Literal, Protocol, Self, TypeAlias, TypeVar
 from uuid import UUID

 from models_library.progress_bar import ProgressReport
-from pydantic import BaseModel, ConfigDict, StringConstraints
+from pydantic import BaseModel, ConfigDict, StringConstraints, model_validator
 from pydantic.config import JsonDict

+T = TypeVar("T", bound=BaseModel)
+
 TaskID: TypeAlias = str
 TaskName: TypeAlias = Annotated[
     str, StringConstraints(strip_whitespace=True, min_length=1)
 ]
 TaskUUID: TypeAlias = UUID
+_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
+WILDCARD: Final[str] = "*"
+
+
+class TaskFilter(BaseModel):
+    model_config = ConfigDict(extra="allow")
+
+    @model_validator(mode="after")
+    def _check_valid_filters(self) -> Self:
+        for key in self.model_dump().keys():
+            if WILDCARD in key or _TASK_ID_KEY_DELIMITATOR in key or "=" in key:
+                raise ValueError(f"Invalid filter key: '{key}'")
+            if (
+                _TASK_ID_KEY_DELIMITATOR in f"{getattr(self, key)}"
+                or "=" in f"{getattr(self, key)}"
+            ):
+                raise ValueError(
+                    f"Invalid filter value for key '{key}': '{getattr(self, key)}'"
+                )
+        return self
+
+    def _build_task_id_prefix(self) -> str:
+        filter_dict = self.model_dump()
+        return _TASK_ID_KEY_DELIMITATOR.join(
+            [f"{key}={filter_dict[key]}" for key in sorted(filter_dict)]
+        )
+
+    def task_id(self, task_uuid: TaskUUID | Literal["*"]) -> TaskID:
+        return _TASK_ID_KEY_DELIMITATOR.join(
+            [self._build_task_id_prefix(), f"task_uuid={task_uuid}"]
+        )

-class TaskFilter(BaseModel): ...
+    @classmethod
+    def recreate_model(cls, task_id: TaskID, model: type[T]) -> T:
+        filter_dict = cls.recreate_data(task_id)
+        return model.model_validate(filter_dict)
+
+    @classmethod
+    def recreate_data(cls, task_id: TaskID) -> dict[str, Any]:
+        """Recreates the filter data from a task_id string
+        Careful: does not validate types. For that use `recreate_model` instead
+        """
+        try:
+            parts = task_id.split(_TASK_ID_KEY_DELIMITATOR)
+            return {
+                key: value
+                for part in parts[:-1]
+                if (key := part.split("=")[0]) and (value := part.split("=")[1])
+            }
+        except (IndexError, ValueError) as err:
+            raise ValueError(f"Invalid task_id format: {task_id}") from err
+
+    @classmethod
+    def task_uuid(cls, task_id: TaskID) -> TaskUUID:
+        try:
+            return UUID(task_id.split(_TASK_ID_KEY_DELIMITATOR)[-1].split("=")[1])
+        except (IndexError, ValueError) as err:
+            raise ValueError(f"Invalid task_id format: {task_id}") from err


 class TaskState(StrEnum):
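A minimal round-trip sketch of the id scheme introduced above. Field names and values are invented for illustration, and note that recreate_data returns string values, as its docstring warns:

    from uuid import uuid4

    # hypothetical filter fields; TaskFilter accepts extra fields
    task_filter = TaskFilter.model_validate({"product_name": "osparc", "user_id": 42})

    task_uuid = uuid4()
    task_id = task_filter.task_id(task_uuid=task_uuid)
    # e.g. "product_name=osparc:user_id=42:task_uuid=<uuid>"

    assert TaskFilter.task_uuid(task_id) == task_uuid
    # values come back as strings; use recreate_model to restore typed fields
    assert TaskFilter.recreate_data(task_id) == {"product_name": "osparc", "user_id": "42"}
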
@@ -92,13 +92,11 @@ async def list_jobs(
     rabbitmq_rpc_client: RabbitMQRPCClient,
     *,
     rpc_namespace: RPCNamespace,
-    filter_: str,
     job_filter: AsyncJobFilter,
 ) -> list[AsyncJobGet]:
     _result: list[AsyncJobGet] = await rabbitmq_rpc_client.request(
         rpc_namespace,
         TypeAdapter(RPCMethodName).validate_python("list_jobs"),
-        filter_=filter_,
         job_filter=job_filter,
         timeout_s=_DEFAULT_TIMEOUT_S,
     )

This file was deleted.
