Skip to content

Commit 97fa75b

Browse files
add tests
1 parent aba2a54 commit 97fa75b

File tree

3 files changed

+79
-32
lines changed

3 files changed

+79
-32
lines changed

packages/celery-library/tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,11 @@ async def mock_celery_app(celery_config: dict[str, Any]) -> Celery:
156156

157157

158158
@pytest.fixture
159-
async def celery_task_manager(
159+
async def task_manager(
160160
mock_celery_app: Celery,
161161
celery_settings: CelerySettings,
162162
use_in_memory_redis: RedisSettings,
163-
) -> AsyncIterator[CeleryTaskManager]:
163+
) -> AsyncIterator[TaskManager]:
164164
register_celery_types()
165165

166166
try:

packages/celery-library/tests/unit/test_async_jobs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,13 @@ async def async_job(task: Task, task_id: TaskID, action: Action, payload: Any) -
138138

139139
@pytest.fixture
140140
async def register_rpc_routes(
141-
async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, celery_task_manager: TaskManager
141+
async_jobs_rabbitmq_rpc_client: RabbitMQRPCClient, task_manager: TaskManager
142142
) -> None:
143143
await async_jobs_rabbitmq_rpc_client.register_router(
144-
_async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager
144+
_async_jobs.router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=task_manager
145145
)
146146
await async_jobs_rabbitmq_rpc_client.register_router(
147-
router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=celery_task_manager
147+
router, ASYNC_JOBS_RPC_NAMESPACE, task_manager=task_manager
148148
)
149149

150150

packages/celery-library/tests/unit/test_task_manager.py

Lines changed: 74 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,12 @@
2525
OwnerMetadata,
2626
TaskID,
2727
TaskState,
28+
TaskStatusEvent,
29+
TaskStatusValue,
2830
TaskUUID,
2931
Wildcard,
3032
)
33+
from servicelib.celery.task_manager import TaskManager
3134
from servicelib.logging_utils import log_context
3235
from tenacity import Retrying, retry_if_exception_type, stop_after_delay, wait_fixed
3336

@@ -89,24 +92,43 @@ async def dreamer_task(task: Task, task_id: TaskID) -> list[int]:
8992
return numbers
9093

9194

95+
async def event_publisher_task(task: Task, task_id: TaskID) -> None:
96+
"""Task that publishes custom events for testing event consumption."""
97+
from servicelib.celery.models import TaskDataEvent
98+
99+
task_manager = get_app_server(task.app).task_manager
100+
101+
data_event = TaskDataEvent(data={"message": "Processing started", "step": 1})
102+
await task_manager.publish_task_event(task_id, data_event)
103+
data_event = TaskDataEvent(data={"message": "Halfway done", "step": 2})
104+
await task_manager.publish_task_event(task_id, data_event)
105+
data_event = TaskDataEvent(data={"message": "Processing completed", "step": 3})
106+
await task_manager.publish_task_event(task_id, data_event)
107+
108+
await task_manager.publish_task_event(
109+
task_id, TaskStatusEvent(data=TaskStatusValue.SUCCESS)
110+
)
111+
112+
92113
@pytest.fixture
93114
def register_celery_tasks() -> Callable[[Celery], None]:
94115
def _(celery_app: Celery) -> None:
95116
register_task(celery_app, fake_file_processor)
96117
register_task(celery_app, failure_task)
97118
register_task(celery_app, dreamer_task)
119+
register_task(celery_app, event_publisher_task)
98120

99121
return _
100122

101123

102124
async def test_submitting_task_calling_async_function_results_with_success_state(
103-
celery_task_manager: CeleryTaskManager,
125+
task_manager: TaskManager,
104126
with_celery_worker: WorkController,
105127
):
106128

107129
owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
108130

109-
task_uuid = await celery_task_manager.submit_task(
131+
task_uuid = await task_manager.submit_task(
110132
ExecutionMetadata(
111133
name=fake_file_processor.__name__,
112134
),
@@ -120,27 +142,25 @@ async def test_submitting_task_calling_async_function_results_with_success_state
120142
stop=stop_after_delay(30),
121143
):
122144
with attempt:
123-
status = await celery_task_manager.get_task_status(
124-
owner_metadata, task_uuid
125-
)
145+
status = await task_manager.get_task_status(owner_metadata, task_uuid)
126146
assert status.task_state == TaskState.SUCCESS
127147

128148
assert (
129-
await celery_task_manager.get_task_status(owner_metadata, task_uuid)
149+
await task_manager.get_task_status(owner_metadata, task_uuid)
130150
).task_state == TaskState.SUCCESS
131151
assert (
132-
await celery_task_manager.get_task_result(owner_metadata, task_uuid)
152+
await task_manager.get_task_result(owner_metadata, task_uuid)
133153
) == "archive.zip"
134154

135155

136156
async def test_submitting_task_with_failure_results_with_error(
137-
celery_task_manager: CeleryTaskManager,
157+
task_manager: TaskManager,
138158
with_celery_worker: WorkController,
139159
):
140160

141161
owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
142162

143-
task_uuid = await celery_task_manager.submit_task(
163+
task_uuid = await task_manager.submit_task(
144164
ExecutionMetadata(
145165
name=failure_task.__name__,
146166
),
@@ -154,23 +174,21 @@ async def test_submitting_task_with_failure_results_with_error(
154174
):
155175

156176
with attempt:
157-
raw_result = await celery_task_manager.get_task_result(
158-
owner_metadata, task_uuid
159-
)
177+
raw_result = await task_manager.get_task_result(owner_metadata, task_uuid)
160178
assert isinstance(raw_result, TransferrableCeleryError)
161179

162-
raw_result = await celery_task_manager.get_task_result(owner_metadata, task_uuid)
180+
raw_result = await task_manager.get_task_result(owner_metadata, task_uuid)
163181
assert f"{raw_result}" == "Something strange happened: BOOM!"
164182

165183

166184
async def test_cancelling_a_running_task_aborts_and_deletes(
167-
celery_task_manager: CeleryTaskManager,
185+
task_manager: TaskManager,
168186
with_celery_worker: WorkController,
169187
):
170188

171189
owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
172190

173-
task_uuid = await celery_task_manager.submit_task(
191+
task_uuid = await task_manager.submit_task(
174192
ExecutionMetadata(
175193
name=dreamer_task.__name__,
176194
),
@@ -179,22 +197,22 @@ async def test_cancelling_a_running_task_aborts_and_deletes(
179197

180198
await asyncio.sleep(3.0)
181199

182-
await celery_task_manager.cancel_task(owner_metadata, task_uuid)
200+
await task_manager.cancel_task(owner_metadata, task_uuid)
183201

184202
with pytest.raises(TaskNotFoundError):
185-
await celery_task_manager.get_task_status(owner_metadata, task_uuid)
203+
await task_manager.get_task_status(owner_metadata, task_uuid)
186204

187-
assert task_uuid not in await celery_task_manager.list_tasks(owner_metadata)
205+
assert task_uuid not in await task_manager.list_tasks(owner_metadata)
188206

189207

190208
async def test_listing_task_uuids_contains_submitted_task(
191-
celery_task_manager: CeleryTaskManager,
209+
task_manager: CeleryTaskManager,
192210
with_celery_worker: WorkController,
193211
):
194212

195213
owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
196214

197-
task_uuid = await celery_task_manager.submit_task(
215+
task_uuid = await task_manager.submit_task(
198216
ExecutionMetadata(
199217
name=dreamer_task.__name__,
200218
),
@@ -207,15 +225,15 @@ async def test_listing_task_uuids_contains_submitted_task(
207225
stop=stop_after_delay(10),
208226
):
209227
with attempt:
210-
tasks = await celery_task_manager.list_tasks(owner_metadata)
228+
tasks = await task_manager.list_tasks(owner_metadata)
211229
assert any(task.uuid == task_uuid for task in tasks)
212230

213-
tasks = await celery_task_manager.list_tasks(owner_metadata)
231+
tasks = await task_manager.list_tasks(owner_metadata)
214232
assert any(task.uuid == task_uuid for task in tasks)
215233

216234

217235
async def test_filtering_listing_tasks(
218-
celery_task_manager: CeleryTaskManager,
236+
task_manager: TaskManager,
219237
with_celery_worker: WorkController,
220238
):
221239
class MyOwnerMetadata(OwnerMetadata):
@@ -232,7 +250,7 @@ class MyOwnerMetadata(OwnerMetadata):
232250
owner_metadata = MyOwnerMetadata(
233251
user_id=user_id, product_name=_faker.word(), owner=_owner
234252
)
235-
task_uuid = await celery_task_manager.submit_task(
253+
task_uuid = await task_manager.submit_task(
236254
ExecutionMetadata(
237255
name=dreamer_task.__name__,
238256
),
@@ -247,7 +265,7 @@ class MyOwnerMetadata(OwnerMetadata):
247265
product_name=_faker.word(),
248266
owner=_owner,
249267
)
250-
task_uuid = await celery_task_manager.submit_task(
268+
task_uuid = await task_manager.submit_task(
251269
ExecutionMetadata(
252270
name=dreamer_task.__name__,
253271
),
@@ -260,9 +278,38 @@ class MyOwnerMetadata(OwnerMetadata):
260278
product_name="*",
261279
owner=_owner,
262280
)
263-
tasks = await celery_task_manager.list_tasks(search_owner_metadata)
281+
tasks = await task_manager.list_tasks(search_owner_metadata)
264282
assert expected_task_uuids == {task.uuid for task in tasks}
265283
finally:
266284
# clean up all tasks. this should ideally be done in the fixture
267285
for task_uuid, owner_metadata in all_tasks:
268-
await celery_task_manager.cancel_task(owner_metadata, task_uuid)
286+
await task_manager.cancel_task(owner_metadata, task_uuid)
287+
288+
289+
async def test_consuming_task_events(
290+
task_manager: TaskManager,
291+
with_celery_worker: WorkController,
292+
):
293+
294+
owner_metadata = MyOwnerMetadata(user_id=42, owner="test-owner")
295+
296+
# Submit a task that publishes events
297+
task_uuid = await task_manager.submit_task(
298+
ExecutionMetadata(
299+
name=event_publisher_task.__name__,
300+
),
301+
owner_metadata=owner_metadata,
302+
)
303+
304+
async for _, event in task_manager.consume_task_events(
305+
owner_metadata=owner_metadata,
306+
task_uuid=task_uuid,
307+
):
308+
task_is_done = isinstance(event, TaskStatusEvent) and event.is_done()
309+
310+
if task_is_done:
311+
break
312+
313+
assert (
314+
await task_manager.get_task_status(owner_metadata, task_uuid)
315+
).task_state == TaskState.SUCCESS

0 commit comments

Comments
 (0)