Skip to content

Commit 687d848

Browse files
tests: refactor
1 parent e8a4510 commit 687d848

File tree

6 files changed

+248
-42
lines changed

6 files changed

+248
-42
lines changed
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from dataclasses import dataclass
2+
from datetime import timedelta
3+
4+
from models_library.progress_bar import ProgressReport
5+
6+
from ..models import (
7+
Task,
8+
TaskContext,
9+
TaskID,
10+
TaskMetadata,
11+
TaskUUID,
12+
build_task_id_prefix,
13+
)
14+
15+
16+
@dataclass
17+
class MemoryTaskInfo:
18+
metadata: TaskMetadata
19+
progress: ProgressReport
20+
21+
22+
class MemoryTaskInfoStore:
23+
def __init__(self) -> None:
24+
self._tasks: dict[TaskID, MemoryTaskInfo] = {}
25+
26+
async def create_task(
27+
self,
28+
task_id: TaskID,
29+
task_metadata: TaskMetadata,
30+
expiry: timedelta,
31+
) -> None:
32+
self._tasks[task_id] = MemoryTaskInfo(
33+
metadata=task_metadata,
34+
progress=ProgressReport(actual_value=0.0),
35+
)
36+
37+
async def exists_task(self, task_id: TaskID) -> bool:
38+
return task_id in self._tasks
39+
40+
async def get_task_metadata(self, task_id: TaskID) -> TaskMetadata | None:
41+
task_info = self._tasks.get(task_id)
42+
if task_info is None:
43+
return None
44+
return task_info.metadata
45+
46+
async def get_task_progress(self, task_id: TaskID) -> ProgressReport | None:
47+
task_info = self._tasks.get(task_id)
48+
if task_info is None:
49+
return None
50+
return task_info.progress
51+
52+
async def list_tasks(self, task_context: TaskContext) -> list[Task]:
53+
tasks = []
54+
task_id_prefix = build_task_id_prefix(task_context)
55+
for task_id, task_info in self._tasks.items():
56+
if task_id.startswith(task_id_prefix):
57+
tasks.append(
58+
Task(
59+
uuid=TaskUUID(task_id[len(task_id_prefix) + 1 :]),
60+
metadata=task_info.metadata,
61+
)
62+
)
63+
return tasks
64+
65+
async def remove_task(self, task_id: TaskID) -> None:
66+
self._tasks.pop(task_id, None)
67+
68+
async def set_task_progress(self, task_id: TaskID, report: ProgressReport) -> None:
69+
task_info = self._tasks.get(task_id)
70+
if task_info is not None:
71+
task_info.progress = report

packages/celery-library/src/celery_library/signals.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from celery import Celery # type: ignore[import-untyped]
66
from celery.worker.worker import WorkController # type: ignore[import-untyped]
7-
from servicelib.base_app_server import BaseAppServer
7+
from servicelib.base_app_server import STARTUP_TIMEOUT, BaseAppServer
88
from servicelib.logging_utils import log_context
99
from settings_library.celery import CelerySettings
1010

@@ -30,6 +30,8 @@ def _init(startup_complete_event: threading.Event) -> None:
3030
loop = asyncio.new_event_loop()
3131
asyncio.set_event_loop(loop)
3232

33+
shutdown_event = asyncio.Event()
34+
3335
app_server.event_loop = loop
3436

3537
async def _setup():
@@ -46,7 +48,9 @@ async def _setup():
4648
)
4749

4850
loop.run_until_complete(_setup())
49-
loop.run_until_complete(app_server.startup(startup_complete_event))
51+
loop.run_until_complete(
52+
app_server.startup(startup_complete_event, shutdown_event)
53+
)
5054

5155
thread = threading.Thread(
5256
group=None,
@@ -57,7 +61,7 @@ async def _setup():
5761
)
5862
thread.start()
5963

60-
startup_complete_event.wait()
64+
startup_complete_event.wait(STARTUP_TIMEOUT * 1.1)
6165

6266

6367
def on_worker_shutdown(sender, **_kwargs) -> None:
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import asyncio
2+
import datetime
3+
from collections.abc import AsyncIterator, Callable
4+
from functools import partial
5+
from threading import Event
6+
from typing import Any
7+
8+
import pytest
9+
from celery import Celery # type: ignore[import-untyped]
10+
from celery.contrib.testing.worker import TestWorkController, start_worker
11+
from celery.signals import worker_init, worker_shutdown
12+
from celery.worker.worker import WorkController
13+
from celery_library.backends._memory import MemoryTaskInfoStore
14+
from celery_library.signals import on_worker_init, on_worker_shutdown
15+
from celery_library.utils import CeleryTaskManager, get_task_manager
16+
from pytest_mock import MockerFixture
17+
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
18+
from pytest_simcore.helpers.typing_env import EnvVarsDict
19+
from servicelib.base_app_server import BaseAppServer
20+
from settings_library.celery import CelerySettings
21+
22+
pytest_plugins = [
23+
"pytest_simcore.environment_configs",
24+
"pytest_simcore.repository_paths",
25+
]
26+
27+
28+
class FakeAppServer(BaseAppServer):
29+
def __init__(self):
30+
self._shutdown_event: asyncio.Event | None = None
31+
32+
async def startup(
33+
self, completed_event: Event, shutdown_event: asyncio.Event
34+
) -> None:
35+
completed_event.set()
36+
await shutdown_event.wait()
37+
38+
async def shutdown(self) -> None:
39+
if self._shutdown_event is not None:
40+
self._shutdown_event.set()
41+
42+
43+
@pytest.fixture
44+
def celery_config() -> dict[str, Any]:
45+
return {
46+
"broker_connection_retry_on_startup": True,
47+
"broker_url": "memory://localhost//",
48+
"result_backend": "cache+memory://localhost//",
49+
"result_expires": datetime.timedelta(days=7),
50+
"result_extended": True,
51+
"pool": "threads",
52+
"task_default_queue": "default",
53+
"task_send_sent_event": True,
54+
"task_track_started": True,
55+
"worker_send_task_events": True,
56+
}
57+
58+
59+
@pytest.fixture
60+
def register_celery_tasks() -> Callable[[Celery], None]:
61+
"""override if tasks are needed"""
62+
63+
def _(celery_app: Celery) -> None: ...
64+
65+
return _
66+
67+
68+
@pytest.fixture
69+
def app_environment(
70+
monkeypatch: pytest.MonkeyPatch,
71+
env_devel_dict: EnvVarsDict,
72+
) -> EnvVarsDict:
73+
return setenvs_from_dict(
74+
monkeypatch,
75+
{
76+
**env_devel_dict,
77+
},
78+
)
79+
80+
81+
@pytest.fixture
82+
def celery_settings(
83+
app_environment: EnvVarsDict,
84+
) -> CelerySettings:
85+
return CelerySettings.create_from_envs()
86+
87+
88+
@pytest.fixture
89+
async def with_storage_celery_worker(
90+
celery_app: Celery,
91+
celery_settings: CelerySettings,
92+
register_celery_tasks: Callable[[Celery], None],
93+
mocker: MockerFixture,
94+
) -> AsyncIterator[TestWorkController]:
95+
mocker.patch(
96+
"celery_library.signals.create_task_manager",
97+
return_value=CeleryTaskManager(
98+
celery_app, celery_settings, MemoryTaskInfoStore()
99+
),
100+
)
101+
102+
def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
103+
return partial(on_worker_init, FakeAppServer(), celery_settings)(
104+
sender, **_kwargs
105+
)
106+
107+
worker_init.connect(_on_worker_init_wrapper)
108+
worker_shutdown.connect(on_worker_shutdown)
109+
110+
register_celery_tasks(celery_app)
111+
112+
with start_worker(
113+
celery_app,
114+
pool="threads",
115+
concurrency=1,
116+
loglevel="info",
117+
perform_ping_check=False,
118+
queues="default",
119+
) as worker:
120+
yield worker
121+
122+
123+
@pytest.fixture
124+
def celery_task_manager(
125+
with_storage_celery_worker: TestWorkController,
126+
) -> CeleryTaskManager:
127+
assert with_storage_celery_worker.app # nosec
128+
assert isinstance(with_storage_celery_worker.app, Celery) # nosec
129+
130+
return get_task_manager(with_storage_celery_worker.app)

services/storage/tests/unit/test_modules_celery.py renamed to packages/celery-library/tests/unit/test_tasks.py

Lines changed: 23 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import pytest
1414
from celery import Celery, Task
1515
from celery.contrib.abortable import AbortableTask
16-
from celery.contrib.testing.worker import TestWorkController
1716
from celery_library.errors import TransferrableCeleryError
1817
from celery_library.models import (
1918
TaskContext,
@@ -28,25 +27,12 @@
2827
from celery_library.task_manager import CeleryTaskManager
2928
from celery_library.utils import get_app_server, get_task_manager
3029
from common_library.errors_classes import OsparcErrorMixin
31-
from fastapi import FastAPI
3230
from models_library.progress_bar import ProgressReport
3331
from servicelib.logging_utils import log_context
34-
from simcore_service_storage.modules.celery import get_task_manager_from_app
3532
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
3633

3734
_logger = logging.getLogger(__name__)
3835

39-
pytest_simcore_core_services_selection = ["postgres", "rabbit"]
40-
pytest_simcore_ops_services_selection = []
41-
42-
43-
@pytest.fixture
44-
def celery_client(
45-
initialized_app: FastAPI,
46-
with_storage_celery_worker: TestWorkController,
47-
) -> CeleryTaskManager:
48-
return get_task_manager_from_app(initialized_app)
49-
5036

5137
async def _fake_file_processor(
5238
celery_app: Celery, task_name: str, task_id: str, files: list[str]
@@ -110,11 +96,11 @@ def _(celery_app: Celery) -> None:
11096

11197

11298
async def test_submitting_task_calling_async_function_results_with_success_state(
113-
celery_client: CeleryTaskManager,
99+
celery_task_manager: CeleryTaskManager,
114100
):
115101
task_context = TaskContext(user_id=42)
116102

117-
task_uuid = await celery_client.submit_task(
103+
task_uuid = await celery_task_manager.submit_task(
118104
TaskMetadata(
119105
name=fake_file_processor.__name__,
120106
),
@@ -128,23 +114,23 @@ async def test_submitting_task_calling_async_function_results_with_success_state
128114
stop=stop_after_delay(30),
129115
):
130116
with attempt:
131-
status = await celery_client.get_task_status(task_context, task_uuid)
117+
status = await celery_task_manager.get_task_status(task_context, task_uuid)
132118
assert status.task_state == TaskState.SUCCESS
133119

134120
assert (
135-
await celery_client.get_task_status(task_context, task_uuid)
121+
await celery_task_manager.get_task_status(task_context, task_uuid)
136122
).task_state == TaskState.SUCCESS
137123
assert (
138-
await celery_client.get_task_result(task_context, task_uuid)
124+
await celery_task_manager.get_task_result(task_context, task_uuid)
139125
) == "archive.zip"
140126

141127

142128
async def test_submitting_task_with_failure_results_with_error(
143-
celery_client: CeleryTaskManager,
129+
celery_task_manager: CeleryTaskManager,
144130
):
145131
task_context = TaskContext(user_id=42)
146132

147-
task_uuid = await celery_client.submit_task(
133+
task_uuid = await celery_task_manager.submit_task(
148134
TaskMetadata(
149135
name=failure_task.__name__,
150136
),
@@ -158,49 +144,53 @@ async def test_submitting_task_with_failure_results_with_error(
158144
):
159145

160146
with attempt:
161-
raw_result = await celery_client.get_task_result(task_context, task_uuid)
147+
raw_result = await celery_task_manager.get_task_result(
148+
task_context, task_uuid
149+
)
162150
assert isinstance(raw_result, TransferrableCeleryError)
163151

164-
raw_result = await celery_client.get_task_result(task_context, task_uuid)
152+
raw_result = await celery_task_manager.get_task_result(task_context, task_uuid)
165153
assert f"{raw_result}" == "Something strange happened: BOOM!"
166154

167155

168156
async def test_cancelling_a_running_task_aborts_and_deletes(
169-
celery_client: CeleryTaskManager,
157+
celery_task_manager: CeleryTaskManager,
170158
):
171159
task_context = TaskContext(user_id=42)
172160

173-
task_uuid = await celery_client.submit_task(
161+
task_uuid = await celery_task_manager.submit_task(
174162
TaskMetadata(
175163
name=dreamer_task.__name__,
176164
),
177165
task_context=task_context,
178166
)
179167

180-
await celery_client.cancel_task(task_context, task_uuid)
168+
await celery_task_manager.cancel_task(task_context, task_uuid)
181169

182170
for attempt in Retrying(
183171
retry=retry_if_exception_type(AssertionError),
184172
wait=wait_fixed(1),
185173
stop=stop_after_delay(30),
186174
):
187175
with attempt:
188-
progress = await celery_client.get_task_status(task_context, task_uuid)
176+
progress = await celery_task_manager.get_task_status(
177+
task_context, task_uuid
178+
)
189179
assert progress.task_state == TaskState.ABORTED
190180

191181
assert (
192-
await celery_client.get_task_status(task_context, task_uuid)
182+
await celery_task_manager.get_task_status(task_context, task_uuid)
193183
).task_state == TaskState.ABORTED
194184

195-
assert task_uuid not in await celery_client.list_tasks(task_context)
185+
assert task_uuid not in await celery_task_manager.list_tasks(task_context)
196186

197187

198188
async def test_listing_task_uuids_contains_submitted_task(
199-
celery_client: CeleryTaskManager,
189+
celery_task_manager: CeleryTaskManager,
200190
):
201191
task_context = TaskContext(user_id=42)
202192

203-
task_uuid = await celery_client.submit_task(
193+
task_uuid = await celery_task_manager.submit_task(
204194
TaskMetadata(
205195
name=dreamer_task.__name__,
206196
),
@@ -213,10 +203,10 @@ async def test_listing_task_uuids_contains_submitted_task(
213203
stop=stop_after_delay(10),
214204
):
215205
with attempt:
216-
tasks = await celery_client.list_tasks(task_context)
206+
tasks = await celery_task_manager.list_tasks(task_context)
217207
assert len(tasks) == 1
218208
assert task_uuid == tasks[0].uuid
219209

220-
tasks = await celery_client.list_tasks(task_context)
210+
tasks = await celery_task_manager.list_tasks(task_context)
221211
assert len(tasks) == 1
222212
assert task_uuid == tasks[0].uuid

0 commit comments

Comments
 (0)