|
3 | 3 |
|
4 | 4 | import asyncio |
5 | 5 | import pickle |
6 | | -from collections.abc import Callable |
| 6 | +from collections.abc import Awaitable, Callable |
7 | 7 | from datetime import timedelta |
8 | 8 | from enum import Enum |
9 | 9 | from typing import Any |
@@ -108,12 +108,19 @@ async def async_job(task: Task, task_id: TaskID, action: Action, payload: Any) - |
108 | 108 |
|
109 | 109 |
|
110 | 110 | @pytest.fixture |
111 | | -async def rpc_client( |
| 111 | +async def register_routes( |
112 | 112 | initialized_fast_api: FastAPI, rpc_namespace: RPCNamespace |
113 | | -) -> RabbitMQRPCClient: |
| 113 | +) -> None: |
114 | 114 | client = initialized_fast_api.state.rabbitmq_rpc_client |
115 | 115 | assert isinstance(client, RabbitMQRPCClient) |
116 | 116 | await client.register_router(router, rpc_namespace, initialized_fast_api) |
| 117 | + |
| 118 | + |
| 119 | +@pytest.fixture |
| 120 | +async def rpc_client( |
| 121 | + rabbitmq_rpc_client: Callable[[str], Awaitable[RabbitMQRPCClient]], |
| 122 | +) -> RabbitMQRPCClient: |
| 123 | + client = await rabbitmq_rpc_client("celery_test_client") |
117 | 124 | return client |
118 | 125 |
|
119 | 126 |
|
@@ -204,9 +211,10 @@ async def _wait_for_job( |
204 | 211 | ], |
205 | 212 | ) |
206 | 213 | async def test_async_jobs_workflow( |
| 214 | + register_routes, |
207 | 215 | rpc_client: RabbitMQRPCClient, |
208 | 216 | rpc_namespace: RPCNamespace, |
209 | | - with_storage_celery_worker: CeleryTaskWorker, |
| 217 | + with_celery_worker: CeleryTaskWorker, |
210 | 218 | user_id: UserID, |
211 | 219 | product_name: ProductName, |
212 | 220 | exposed_rpc_start: str, |
@@ -257,7 +265,7 @@ async def test_async_jobs_cancel( |
257 | 265 | register_rpc_routes: None, |
258 | 266 | rpc_namespace: RPCNamespace, |
259 | 267 | storage_rabbitmq_rpc_client: RabbitMQRPCClient, |
260 | | - with_storage_celery_worker: CeleryTaskWorker, |
| 268 | + with_celery_worker: CeleryTaskWorker, |
261 | 269 | user_id: UserID, |
262 | 270 | product_name: ProductName, |
263 | 271 | exposed_rpc_start: str, |
@@ -325,7 +333,7 @@ async def test_async_jobs_raises( |
325 | 333 | register_rpc_routes: None, |
326 | 334 | rpc_namespace: RPCNamespace, |
327 | 335 | storage_rabbitmq_rpc_client: RabbitMQRPCClient, |
328 | | - with_storage_celery_worker: CeleryTaskWorker, |
| 336 | + with_celery_worker: CeleryTaskWorker, |
329 | 337 | user_id: UserID, |
330 | 338 | product_name: ProductName, |
331 | 339 | exposed_rpc_start: str, |
|
0 commit comments