Skip to content

Commit 614df7d

Browse files
use hooks
1 parent 6e1fc9f commit 614df7d

File tree

5 files changed

+33
-36
lines changed

5 files changed

+33
-36
lines changed

packages/celery-library/requirements/_test.in

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# testing
1212
coverage
1313
faker
14-
fastapi
1514
httpx
1615
pint
1716
pytest

packages/celery-library/tests/conftest.py

Lines changed: 11 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
# pylint: disable=redefined-outer-name
22
# pylint: disable=unused-argument
33

4-
import asyncio
54
import datetime
65
from collections.abc import AsyncIterator, Callable
76
from functools import partial
8-
from threading import Event
97
from typing import Any
108

119
import pytest
@@ -32,19 +30,11 @@
3230

3331

3432
class FakeAppServer(BaseAppServer):
35-
def __init__(self):
36-
self._shutdown_event: asyncio.Event | None = None
33+
async def on_startup(self) -> None:
34+
pass
3735

38-
async def startup(
39-
self, completed_event: Event, shutdown_event: asyncio.Event
40-
) -> None:
41-
self._shutdown_event = shutdown_event
42-
completed_event.set()
43-
await self._shutdown_event.wait()
44-
45-
async def shutdown(self) -> None:
46-
if self._shutdown_event is not None:
47-
self._shutdown_event.set()
36+
async def on_shutdown(self) -> None:
37+
pass
4838

4939

5040
@pytest.fixture
@@ -80,13 +70,17 @@ def celery_settings(
8070
return CelerySettings.create_from_envs()
8171

8272

73+
@pytest.fixture
74+
def app_server() -> BaseAppServer:
75+
return FakeAppServer()
76+
77+
8378
@pytest.fixture(scope="session")
8479
def celery_config() -> dict[str, Any]:
8580
return {
8681
"broker_connection_retry_on_startup": True,
8782
"broker_url": "memory://localhost//",
8883
"result_backend": "cache+memory://localhost//",
89-
# "result_backend": celery_settings.CELERY_REDIS_RESULT_BACKEND.build_redis_dsn(RedisDatabase.CELERY_TASKS),
9084
"result_expires": datetime.timedelta(days=7),
9185
"result_extended": True,
9286
"pool": "threads",
@@ -100,13 +94,12 @@ def celery_config() -> dict[str, Any]:
10094
@pytest.fixture
10195
async def with_storage_celery_worker(
10296
celery_app: Celery,
97+
app_server: BaseAppServer,
10398
celery_settings: CelerySettings,
10499
register_celery_tasks: Callable[[Celery], None],
105100
) -> AsyncIterator[TestWorkController]:
106101
def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
107-
return partial(on_worker_init, FakeAppServer(), celery_settings)(
108-
sender, **_kwargs
109-
)
102+
return partial(on_worker_init, app_server, celery_settings)(sender, **_kwargs)
110103

111104
worker_init.connect(_on_worker_init_wrapper)
112105
worker_shutdown.connect(on_worker_shutdown)

packages/celery-library/tests/unit/test_errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
AccessRightError(user_id=1, file_id="a/path/to/a/file.txt", location_id=0),
1414
],
1515
)
16-
def test_workflow(original_error: Exception):
16+
def test_error(original_error: Exception):
1717
try:
1818
raise original_error # noqa: TRY301
1919
except Exception as e: # pylint: disable=broad-exception-caught

packages/service-library/src/servicelib/celery/app_server.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818

1919
class BaseAppServer(ABC):
20+
def __init__(self) -> None:
21+
self._shutdown_event: asyncio.Event | None = None
22+
2023
@property
2124
def fastapi_app(self) -> "FastAPI":
2225
raise NotImplementedError
@@ -34,11 +37,23 @@ def event_loop(self, loop: AbstractEventLoop) -> None:
3437
self._event_loop = loop
3538

3639
@abstractmethod
40+
async def on_startup(self) -> None:
41+
raise NotImplementedError
42+
3743
async def startup(
3844
self, completed_event: threading.Event, shutdown_event: asyncio.Event
3945
) -> None:
40-
pass
46+
self._shutdown_event = shutdown_event
47+
completed_event.set()
48+
await self.on_startup()
49+
await self._shutdown_event.wait()
4150

4251
@abstractmethod
43-
async def shutdown(self):
44-
pass
52+
async def on_shutdown(self) -> None:
53+
raise NotImplementedError
54+
55+
async def shutdown(self) -> None:
56+
if self._shutdown_event is not None:
57+
self._shutdown_event.set()
58+
59+
await self.on_shutdown()
Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import asyncio
2-
import threading
31
from datetime import timedelta
42
from typing import Final
53

@@ -14,32 +12,24 @@
1412

1513
class FastAPIAppServer(BaseAppServer):
1614
def __init__(self, app: FastAPI):
15+
super().__init__()
1716
self._app = app
1817
self._lifespan_manager: LifespanManager | None = None
19-
self._shutdown_event: asyncio.Event | None = None
2018

2119
@property
2220
def fastapi_app(self) -> FastAPI:
2321
assert isinstance(self._app, FastAPI) # nosec
2422
return self._app
2523

26-
async def startup(
27-
self, completed_event: threading.Event, shutdown_event: asyncio.Event
28-
):
24+
async def on_startup(self) -> None:
2925
self._lifespan_manager = LifespanManager(
3026
self.fastapi_app,
3127
startup_timeout=_STARTUP_TIMEOUT,
3228
shutdown_timeout=_SHUTDOWN_TIMEOUT,
3329
)
34-
self._shutdown_event = shutdown_event
3530
await self._lifespan_manager.__aenter__()
36-
completed_event.set()
37-
await self._shutdown_event.wait()
38-
39-
async def shutdown(self):
40-
if self._shutdown_event is not None:
41-
self._shutdown_event.set()
4231

32+
async def on_shutdown(self) -> None:
4333
if self._lifespan_manager is None:
4434
return
4535
await self._lifespan_manager.__aexit__(None, None, None)

0 commit comments

Comments
 (0)