Skip to content

Commit f5959cf

Browse files
authored
Merge branch 'master' into feature/search-files
2 parents bf0b98e + 66aebbf commit f5959cf

File tree

37 files changed

+291
-249
lines changed

37 files changed

+291
-249
lines changed

packages/celery-library/src/celery_library/signals.py

Lines changed: 0 additions & 52 deletions
This file was deleted.

packages/celery-library/src/celery_library/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from servicelib.celery.models import TaskKey
1414

1515
from .errors import encode_celery_transferrable_error
16-
from .utils import get_app_server
16+
from .worker.app_server import get_app_server
1717

1818
_logger = logging.getLogger(__name__)
1919

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from collections.abc import Callable
2+
3+
from celery import Celery # type: ignore[import-untyped]
4+
from servicelib.celery.app_server import BaseAppServer
5+
from settings_library.celery import CelerySettings
6+
7+
from ..app import create_app
8+
from .signals import register_worker_signals
9+
10+
11+
def create_worker_app(
12+
settings: CelerySettings,
13+
register_worker_tasks_cb: Callable[[Celery], None],
14+
app_server_factory_cb: Callable[[], BaseAppServer],
15+
) -> Celery:
16+
app = create_app(settings)
17+
register_worker_tasks_cb(app)
18+
register_worker_signals(app, settings, app_server_factory_cb)
19+
20+
return app
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import asyncio
2+
import threading
3+
from collections.abc import Callable
4+
5+
from celery import Celery # type: ignore[import-untyped]
6+
from celery.signals import ( # type: ignore[import-untyped]
7+
worker_init,
8+
worker_process_init,
9+
worker_process_shutdown,
10+
worker_shutdown,
11+
)
12+
from servicelib.celery.app_server import BaseAppServer
13+
from settings_library.celery import CeleryPoolType, CelerySettings
14+
15+
from .app_server import get_app_server, set_app_server
16+
17+
18+
def _worker_init_wrapper(
19+
app: Celery, app_server_factory: Callable[[], BaseAppServer]
20+
) -> Callable[..., None]:
21+
def _worker_init_handler(**_kwargs) -> None:
22+
startup_complete_event = threading.Event()
23+
24+
def _init(startup_complete_event: threading.Event) -> None:
25+
loop = asyncio.new_event_loop()
26+
asyncio.set_event_loop(loop)
27+
28+
app_server = app_server_factory()
29+
app_server.event_loop = loop
30+
31+
set_app_server(app, app_server)
32+
33+
loop.run_until_complete(
34+
app_server.run_until_shutdown(startup_complete_event)
35+
)
36+
37+
thread = threading.Thread(
38+
group=None,
39+
target=_init,
40+
name="app_server_init",
41+
args=(startup_complete_event,),
42+
daemon=True,
43+
)
44+
thread.start()
45+
46+
startup_complete_event.wait()
47+
48+
return _worker_init_handler
49+
50+
51+
def _worker_shutdown_wrapper(app: Celery) -> Callable[..., None]:
52+
def _worker_shutdown_handler(**_kwargs) -> None:
53+
get_app_server(app).shutdown_event.set()
54+
55+
return _worker_shutdown_handler
56+
57+
58+
def register_worker_signals(
59+
app: Celery,
60+
settings: CelerySettings,
61+
app_server_factory: Callable[[], BaseAppServer],
62+
) -> None:
63+
match settings.CELERY_POOL:
64+
case CeleryPoolType.PREFORK:
65+
worker_process_init.connect(
66+
_worker_init_wrapper(app, app_server_factory), weak=False
67+
)
68+
worker_process_shutdown.connect(_worker_shutdown_wrapper(app), weak=False)
69+
case _:
70+
worker_init.connect(_worker_init_wrapper(app, app_server_factory), weak=False)
71+
worker_shutdown.connect(_worker_shutdown_wrapper(app), weak=False)

packages/celery-library/tests/conftest.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,16 @@
1414
start_worker,
1515
)
1616
from celery.signals import worker_init, worker_shutdown
17-
from celery.worker.worker import WorkController
1817
from celery_library.backends.redis import RedisTaskStore
19-
from celery_library.signals import on_worker_init, on_worker_shutdown
2018
from celery_library.task_manager import CeleryTaskManager
2119
from celery_library.types import register_celery_types
20+
from celery_library.worker.signals import _worker_init_wrapper, _worker_shutdown_wrapper
2221
from pytest_simcore.helpers.monkeypatch_envs import setenvs_from_dict
2322
from pytest_simcore.helpers.typing_env import EnvVarsDict
2423
from servicelib.celery.app_server import BaseAppServer
2524
from servicelib.celery.task_manager import TaskManager
2625
from servicelib.redis import RedisClientSDK
27-
from settings_library.celery import CelerySettings
26+
from settings_library.celery import CeleryPoolType, CelerySettings
2827
from settings_library.redis import RedisDatabase, RedisSettings
2928

3029
pytest_plugins = [
@@ -104,11 +103,6 @@ def celery_settings(
104103
return CelerySettings.create_from_envs()
105104

106105

107-
@pytest.fixture
108-
def app_server(celery_app: Celery, celery_settings: CelerySettings) -> BaseAppServer:
109-
return FakeAppServer(app=celery_app, settings=celery_settings)
110-
111-
112106
@pytest.fixture(scope="session")
113107
def celery_config() -> dict[str, Any]:
114108
return {
@@ -128,21 +122,25 @@ def celery_config() -> dict[str, Any]:
128122
@pytest.fixture
129123
async def with_celery_worker(
130124
celery_app: Celery,
131-
app_server: BaseAppServer,
125+
celery_settings: CelerySettings,
132126
register_celery_tasks: Callable[[Celery], None],
133127
) -> AsyncIterator[TestWorkController]:
134-
def _on_worker_init_wrapper(sender: WorkController, **_kwargs):
135-
return on_worker_init(sender, app_server, **_kwargs)
136128

137-
worker_init.connect(_on_worker_init_wrapper)
138-
worker_shutdown.connect(on_worker_shutdown)
129+
def _app_server_factory() -> BaseAppServer:
130+
return FakeAppServer(app=celery_app, settings=celery_settings)
131+
132+
# NOTE: explicitly connect the signals in tests
133+
worker_init.connect(
134+
_worker_init_wrapper(celery_app, _app_server_factory), weak=False
135+
)
136+
worker_shutdown.connect(_worker_shutdown_wrapper(celery_app), weak=False)
139137

140138
register_celery_tasks(celery_app)
141139

142140
with start_worker(
143141
celery_app,
144142
concurrency=1,
145-
pool="threads",
143+
pool=CeleryPoolType.THREADS,
146144
loglevel="info",
147145
perform_ping_check=False,
148146
queues="default",

packages/celery-library/tests/unit/test_task_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from celery_library.errors import TaskNotFoundError, TransferrableCeleryError
1717
from celery_library.task import register_task
1818
from celery_library.task_manager import CeleryTaskManager
19-
from celery_library.utils import get_app_server
19+
from celery_library.worker.app_server import get_app_server
2020
from common_library.errors_classes import OsparcErrorMixin
2121
from faker import Faker
2222
from models_library.progress_bar import ProgressReport

packages/service-library/src/servicelib/fastapi/celery/app_server.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from ...celery.app_server import BaseAppServer
1010
from ...celery.task_manager import TaskManager
1111

12+
_STARTUP_TIMEOUT: Final[float] = datetime.timedelta(minutes=5).total_seconds()
1213
_SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds()
1314

1415
_logger = logging.getLogger(__name__)
@@ -27,9 +28,10 @@ async def run_until_shutdown(
2728
) -> None:
2829
async with LifespanManager(
2930
self.app,
30-
startup_timeout=None, # waits for full app initialization (DB migrations, etc.)
31+
startup_timeout=_STARTUP_TIMEOUT,
3132
shutdown_timeout=_SHUTDOWN_TIMEOUT,
3233
):
33-
_logger.info("fastapi app initialized")
34+
_logger.info("FastAPI initialized: %s", self.app)
3435
startup_completed_event.set()
3536
await self.shutdown_event.wait() # NOTE: wait here until shutdown is requested
37+
_logger.info("FastAPI shutdown completed: %s", self.app)

0 commit comments

Comments
 (0)