Skip to content

Commit b8359d9

Browse files
committed
simplify wildcard usage
1 parent 5603624 commit b8359d9

File tree

2 files changed

+9
-21
lines changed

2 files changed

+9
-21
lines changed

packages/celery-library/tests/unit/test_tasks.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from common_library.errors_classes import OsparcErrorMixin
2121
from faker import Faker
2222
from models_library.progress_bar import ProgressReport
23-
from pydantic import BaseModel
2423
from servicelib.celery.models import (
2524
TaskFilter,
2625
TaskID,
@@ -214,7 +213,7 @@ async def test_listing_task_uuids_contains_submitted_task(
214213
async def test_filtering_listing_tasks(
215214
celery_task_manager: CeleryTaskManager,
216215
):
217-
class MyFilter(BaseModel):
216+
class MyFilter(TaskFilter):
218217
user_id: int
219218
product_name: str | Wildcard
220219
client_app: str | Wildcard
@@ -223,12 +222,11 @@ class MyFilter(BaseModel):
223222
expected_task_uuids: set[TaskUUID] = set()
224223

225224
for _ in range(5):
226-
myfilter = MyFilter(
225+
task_filter = MyFilter(
227226
user_id=user_id,
228227
product_name=_faker.word(),
229228
client_app=_faker.word(),
230229
)
231-
task_filter = TaskFilter.model_validate(myfilter.model_dump())
232230
task_uuid = await celery_task_manager.submit_task(
233231
TaskMetadata(
234232
name=dreamer_task.__name__,
@@ -238,12 +236,11 @@ class MyFilter(BaseModel):
238236
expected_task_uuids.add(task_uuid)
239237

240238
for _ in range(3):
241-
myfilter = MyFilter(
239+
task_filter = MyFilter(
242240
user_id=_faker.pyint(min_value=100, max_value=200),
243241
product_name=_faker.word(),
244242
client_app=_faker.word(),
245243
)
246-
task_filter = TaskFilter.model_validate(myfilter.model_dump())
247244
await celery_task_manager.submit_task(
248245
TaskMetadata(
249246
name=dreamer_task.__name__,
@@ -256,7 +253,5 @@ class MyFilter(BaseModel):
256253
product_name=Wildcard(),
257254
client_app=Wildcard(),
258255
)
259-
tasks = await celery_task_manager.list_tasks(
260-
TaskFilter.model_validate(search_filter.model_dump())
261-
)
256+
tasks = await celery_task_manager.list_tasks(search_filter)
262257
assert expected_task_uuids == {task.uuid for task in tasks}

packages/service-library/src/servicelib/celery/models.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,9 @@
1919
_FORBIDDEN_CHARS = (_WILDCARD, _TASK_ID_KEY_DELIMITATOR, "=")
2020

2121

22-
class Wildcard: ...
23-
24-
25-
def _replace_wildcard(value: Any) -> str:
26-
if isinstance(value, Wildcard):
22+
class Wildcard:
23+
def __str__(self) -> str:
2724
return _WILDCARD
28-
return f"{value}"
2925

3026

3127
class TaskFilter(BaseModel):
@@ -47,7 +43,7 @@ class MyTaskFilter(TaskFilter):
4743
4844
"""
4945

50-
model_config = ConfigDict(extra="allow")
46+
model_config = ConfigDict(extra="allow", arbitrary_types_allowed=True)
5147

5248
@model_validator(mode="after")
5349
def _check_valid_filters(self) -> Self:
@@ -63,17 +59,14 @@ def _check_valid_filters(self) -> Self:
6359
def _build_task_id_prefix(self) -> str:
6460
filter_dict = self.model_dump()
6561
return _TASK_ID_KEY_DELIMITATOR.join(
66-
[
67-
f"{key}={_replace_wildcard(filter_dict[key])}"
68-
for key in sorted(filter_dict)
69-
]
62+
[f"{key}={filter_dict[key]}" for key in sorted(filter_dict)]
7063
)
7164

7265
def get_task_id(self, task_uuid: TaskUUID | Wildcard) -> TaskID:
7366
return _TASK_ID_KEY_DELIMITATOR.join(
7467
[
7568
self._build_task_id_prefix(),
76-
f"task_uuid={_replace_wildcard(task_uuid)}",
69+
f"task_uuid={task_uuid}",
7770
]
7871
)
7972

0 commit comments

Comments
 (0)