Skip to content

Commit 9a461d5

Browse files
committed
further refinements
1 parent 1fbd307 commit 9a461d5

File tree

6 files changed

+40
-33
lines changed

6 files changed

+40
-33
lines changed

packages/celery-library/src/celery_library/rpc/_async_jobs.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ async def cancel(
3636
):
3737
assert task_manager # nosec
3838
assert job_filter # nosec
39+
task_filter = TaskFilter.from_async_job_filter(job_filter)
3940
try:
4041
await task_manager.cancel_task(
41-
task_filter=TaskFilter.model_validate(job_filter.model_dump()),
42+
task_filter=task_filter,
4243
task_uuid=job_id,
4344
)
4445
except CeleryError as exc:
@@ -52,9 +53,10 @@ async def status(
5253
assert task_manager # nosec
5354
assert job_filter # nosec
5455

56+
task_filter = TaskFilter.from_async_job_filter(job_filter)
5557
try:
5658
task_status = await task_manager.get_task_status(
57-
task_filter=TaskFilter.model_validate(job_filter.model_dump()),
59+
task_filter=task_filter,
5860
task_uuid=job_id,
5961
)
6062
except CeleryError as exc:
@@ -82,7 +84,7 @@ async def result(
8284
assert job_id # nosec
8385
assert job_filter # nosec
8486

85-
task_filter = TaskFilter.model_validate(job_filter.model_dump())
87+
task_filter = TaskFilter.from_async_job_filter(job_filter)
8688

8789
try:
8890
_status = await task_manager.get_task_status(
@@ -129,9 +131,10 @@ async def list_jobs(
129131
) -> list[AsyncJobGet]:
130132
_ = filter_
131133
assert task_manager # nosec
134+
task_filter = TaskFilter.from_async_job_filter(job_filter)
132135
try:
133136
tasks = await task_manager.list_tasks(
134-
task_filter=TaskFilter.model_validate(job_filter.model_dump()),
137+
task_filter=task_filter,
135138
)
136139
except CeleryError as exc:
137140
raise JobSchedulerError(exc=f"{exc}") from exc

packages/celery-library/src/celery_library/utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
_TASK_ID_KEY_DELIMITATOR: Final[str] = ":"
1010

1111

12-
def build_task_id_prefix(task_context: TaskFilter) -> str:
13-
_dict = task_context.model_dump()
12+
def build_task_id_prefix(task_filter: TaskFilter) -> str:
13+
_dict = task_filter.model_dump()
1414
return _TASK_ID_KEY_DELIMITATOR.join([f"{_dict[key]}" for key in sorted(_dict)])
1515

1616

17-
def build_task_id(task_context: TaskFilter, task_uuid: TaskUUID) -> TaskID:
17+
def build_task_id(task_filter: TaskFilter, task_uuid: TaskUUID) -> TaskID:
1818
return _TASK_ID_KEY_DELIMITATOR.join(
19-
[build_task_id_prefix(task_context), f"{task_uuid}"]
19+
[build_task_id_prefix(task_filter), f"{task_uuid}"]
2020
)
2121

2222

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ async def rpc_sync_job(
8282
task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any
8383
) -> AsyncJobGet:
8484
task_name = sync_job.__name__
85-
task_filter = TaskFilter.model_validate(job_filter.model_dump())
85+
task_filter = TaskFilter.from_async_job_filter(job_filter)
8686
task_uuid = await task_manager.submit_task(
8787
TaskMetadata(name=task_name), task_filter=task_filter, **kwargs
8888
)
@@ -95,7 +95,7 @@ async def rpc_async_job(
9595
task_manager: TaskManager, *, job_filter: AsyncJobFilter, **kwargs: Any
9696
) -> AsyncJobGet:
9797
task_name = async_job.__name__
98-
task_filter = TaskFilter.model_validate(job_filter.model_dump())
98+
task_filter = TaskFilter.from_async_job_filter(job_filter)
9999
task_uuid = await task_manager.submit_task(
100100
TaskMetadata(name=task_name), task_filter=task_filter, **kwargs
101101
)

packages/celery-library/tests/unit/test_tasks.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,13 @@ def _(celery_app: Celery) -> None:
9393
async def test_submitting_task_calling_async_function_results_with_success_state(
9494
celery_task_manager: CeleryTaskManager,
9595
):
96-
task_context = TaskFilter(user_id=42)
96+
task_filter = TaskFilter(user_id=42)
9797

9898
task_uuid = await celery_task_manager.submit_task(
9999
TaskMetadata(
100100
name=fake_file_processor.__name__,
101101
),
102-
task_context=task_context,
102+
task_filter=task_filter,
103103
files=[f"file{n}" for n in range(5)],
104104
)
105105

@@ -109,27 +109,27 @@ async def test_submitting_task_calling_async_function_results_with_success_state
109109
stop=stop_after_delay(30),
110110
):
111111
with attempt:
112-
status = await celery_task_manager.get_task_status(task_context, task_uuid)
112+
status = await celery_task_manager.get_task_status(task_filter, task_uuid)
113113
assert status.task_state == TaskState.SUCCESS
114114

115115
assert (
116-
await celery_task_manager.get_task_status(task_context, task_uuid)
116+
await celery_task_manager.get_task_status(task_filter, task_uuid)
117117
).task_state == TaskState.SUCCESS
118118
assert (
119-
await celery_task_manager.get_task_result(task_context, task_uuid)
119+
await celery_task_manager.get_task_result(task_filter, task_uuid)
120120
) == "archive.zip"
121121

122122

123123
async def test_submitting_task_with_failure_results_with_error(
124124
celery_task_manager: CeleryTaskManager,
125125
):
126-
task_context = TaskFilter(user_id=42)
126+
task_filter = TaskFilter(user_id=42)
127127

128128
task_uuid = await celery_task_manager.submit_task(
129129
TaskMetadata(
130130
name=failure_task.__name__,
131131
),
132-
task_context=task_context,
132+
task_filter=task_filter,
133133
)
134134

135135
for attempt in Retrying(
@@ -140,58 +140,56 @@ async def test_submitting_task_with_failure_results_with_error(
140140

141141
with attempt:
142142
raw_result = await celery_task_manager.get_task_result(
143-
task_context, task_uuid
143+
task_filter, task_uuid
144144
)
145145
assert isinstance(raw_result, TransferrableCeleryError)
146146

147-
raw_result = await celery_task_manager.get_task_result(task_context, task_uuid)
147+
raw_result = await celery_task_manager.get_task_result(task_filter, task_uuid)
148148
assert f"{raw_result}" == "Something strange happened: BOOM!"
149149

150150

151151
async def test_cancelling_a_running_task_aborts_and_deletes(
152152
celery_task_manager: CeleryTaskManager,
153153
):
154-
task_context = TaskFilter(user_id=42)
154+
task_filter = TaskFilter(user_id=42)
155155

156156
task_uuid = await celery_task_manager.submit_task(
157157
TaskMetadata(
158158
name=dreamer_task.__name__,
159159
),
160-
task_context=task_context,
160+
task_filter=task_filter,
161161
)
162162

163163
await asyncio.sleep(3.0)
164164

165-
await celery_task_manager.cancel_task(task_context, task_uuid)
165+
await celery_task_manager.cancel_task(task_filter, task_uuid)
166166

167167
for attempt in Retrying(
168168
retry=retry_if_exception_type(AssertionError),
169169
wait=wait_fixed(1),
170170
stop=stop_after_delay(30),
171171
):
172172
with attempt:
173-
progress = await celery_task_manager.get_task_status(
174-
task_context, task_uuid
175-
)
173+
progress = await celery_task_manager.get_task_status(task_filter, task_uuid)
176174
assert progress.task_state == TaskState.ABORTED
177175

178176
assert (
179-
await celery_task_manager.get_task_status(task_context, task_uuid)
177+
await celery_task_manager.get_task_status(task_filter, task_uuid)
180178
).task_state == TaskState.ABORTED
181179

182-
assert task_uuid not in await celery_task_manager.list_tasks(task_context)
180+
assert task_uuid not in await celery_task_manager.list_tasks(task_filter)
183181

184182

185183
async def test_listing_task_uuids_contains_submitted_task(
186184
celery_task_manager: CeleryTaskManager,
187185
):
188-
task_context = TaskFilter(user_id=42)
186+
task_filter = TaskFilter(user_id=42)
189187

190188
task_uuid = await celery_task_manager.submit_task(
191189
TaskMetadata(
192190
name=dreamer_task.__name__,
193191
),
194-
task_context=task_context,
192+
task_filter=task_filter,
195193
)
196194

197195
for attempt in Retrying(
@@ -200,8 +198,8 @@ async def test_listing_task_uuids_contains_submitted_task(
200198
stop=stop_after_delay(10),
201199
):
202200
with attempt:
203-
tasks = await celery_task_manager.list_tasks(task_context)
201+
tasks = await celery_task_manager.list_tasks(task_filter)
204202
assert any(task.uuid == task_uuid for task in tasks)
205203

206-
tasks = await celery_task_manager.list_tasks(task_context)
204+
tasks = await celery_task_manager.list_tasks(task_filter)
207205
assert any(task.uuid == task_uuid for task in tasks)

packages/service-library/src/servicelib/celery/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Annotated, Protocol, TypeAlias
44
from uuid import UUID
55

6+
from models_library.api_schemas_rpc_async_jobs.async_jobs import AsyncJobFilter
67
from models_library.progress_bar import ProgressReport
78
from pydantic import BaseModel, ConfigDict, StringConstraints
89

@@ -16,6 +17,10 @@
1617
class TaskFilter(BaseModel):
1718
model_config = ConfigDict(extra="forbid")
1819

20+
@classmethod
21+
def from_async_job_filter(cls, async_job_filter: AsyncJobFilter) -> "TaskFilter":
22+
cls.model_validate(async_job_filter.model_dump())
23+
1924

2025
class TaskState(StrEnum):
2126
PENDING = "PENDING"

services/storage/src/simcore_service_storage/api/rpc/_simcore_s3.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
)
55
from models_library.api_schemas_storage.storage_schemas import FoldersBody
66
from models_library.api_schemas_webserver.storage import PathToExport
7-
from servicelib.celery.models import TaskMetadata, TasksQueue
7+
from servicelib.celery.models import TaskFilter, TaskMetadata, TasksQueue
88
from servicelib.celery.task_manager import TaskManager
99
from servicelib.rabbitmq import RPCRouter
1010

@@ -20,11 +20,12 @@ async def copy_folders_from_project(
2020
body: FoldersBody,
2121
) -> AsyncJobGet:
2222
task_name = deep_copy_files_from_project.__name__
23+
task_filter = TaskFilter.model_validate(job_filter.model_dump())
2324
task_uuid = await task_manager.submit_task(
2425
task_metadata=TaskMetadata(
2526
name=task_name,
2627
),
27-
task_filter=job_filter,
28+
task_filter=task_filter,
2829
user_id=job_filter.user_id,
2930
body=body,
3031
)

0 commit comments

Comments
 (0)