Skip to content

Commit c08faf9

Browse files
feat: wrap app server
1 parent cb91384 commit c08faf9

File tree

10 files changed

+117
-87
lines changed

10 files changed

+117
-87
lines changed

packages/celery-library/src/celery_library/__init__.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from asyncio import AbstractEventLoop
32

43
from fastapi import FastAPI
54
from settings_library.celery import CelerySettings
@@ -27,13 +26,3 @@ def get_celery_client(app: FastAPI) -> CeleryTaskManager:
2726
celery_client = app.state.celery_client
2827
assert isinstance(celery_client, CeleryTaskManager)
2928
return celery_client
30-
31-
32-
def get_event_loop(app: FastAPI) -> AbstractEventLoop:
33-
event_loop = app.state.event_loop
34-
assert isinstance(event_loop, AbstractEventLoop)
35-
return event_loop
36-
37-
38-
def set_event_loop(app: FastAPI, event_loop: AbstractEventLoop) -> None:
39-
app.state.event_loop = event_loop
Lines changed: 14 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,39 @@
11
import asyncio
2-
import datetime
32
import logging
43
import threading
5-
from collections.abc import Callable
6-
from typing import Final
74

8-
from asgi_lifespan import LifespanManager
95
from celery import Celery # type: ignore[import-untyped]
106
from celery.worker.worker import WorkController # type: ignore[import-untyped]
11-
from fastapi import FastAPI
7+
from servicelib.base_app_server import BaseAppServer
128
from servicelib.logging_utils import log_context
139
from settings_library.celery import CelerySettings
1410

15-
from . import set_event_loop
1611
from .common import create_task_manager
1712
from .utils import (
18-
get_fastapi_app,
19-
set_fastapi_app,
13+
get_app_server,
14+
set_app_server,
2015
set_task_manager,
2116
)
2217

2318
_logger = logging.getLogger(__name__)
2419

25-
_SHUTDOWN_TIMEOUT: Final[float] = datetime.timedelta(seconds=10).total_seconds()
26-
_STARTUP_TIMEOUT: Final[float] = datetime.timedelta(minutes=1).total_seconds()
27-
2820

2921
def on_worker_init(
30-
app_factory: Callable[[], FastAPI],
22+
app_server: BaseAppServer,
3123
celery_settings: CelerySettings,
3224
sender: WorkController,
3325
**_kwargs,
3426
) -> None:
35-
startup_complete_event = threading.Event()
36-
37-
def _init(startup_complete_event: threading.Event) -> None:
27+
def _init() -> None:
3828
loop = asyncio.new_event_loop()
3929
asyncio.set_event_loop(loop)
40-
shutdown_event = asyncio.Event()
4130

42-
fastapi_app = app_factory()
43-
assert isinstance(fastapi_app, FastAPI) # nosec
31+
app_server.event_loop = loop
4432

4533
async def setup_task_manager():
4634
assert sender.app # nosec
4735
assert isinstance(sender.app, Celery) # nosec
36+
4837
set_task_manager(
4938
sender.app,
5039
create_task_manager(
@@ -53,45 +42,22 @@ async def setup_task_manager():
5342
),
5443
)
5544

56-
async def fastapi_lifespan(
57-
startup_complete_event: threading.Event, shutdown_event: asyncio.Event
58-
) -> None:
59-
async with LifespanManager(
60-
fastapi_app,
61-
startup_timeout=_STARTUP_TIMEOUT,
62-
shutdown_timeout=_SHUTDOWN_TIMEOUT,
63-
):
64-
try:
65-
_logger.info("fastapi APP started!")
66-
startup_complete_event.set()
67-
await shutdown_event.wait()
68-
except asyncio.CancelledError:
69-
_logger.warning("Lifespan task cancelled")
70-
71-
fastapi_app.state.shutdown_event = shutdown_event
72-
set_event_loop(fastapi_app, loop)
73-
74-
set_fastapi_app(sender.app, fastapi_app)
45+
set_app_server(sender.app, app_server)
7546
loop.run_until_complete(setup_task_manager())
76-
loop.run_until_complete(
77-
fastapi_lifespan(startup_complete_event, shutdown_event)
78-
)
47+
loop.run_until_complete(app_server.startup())
7948

8049
thread = threading.Thread(
8150
group=None,
8251
target=_init,
83-
name="fastapi_app",
84-
args=(startup_complete_event,),
52+
name="app_server_init",
8553
daemon=True,
8654
)
8755
thread.start()
88-
# ensure the fastapi app is ready before going on
89-
startup_complete_event.wait(_STARTUP_TIMEOUT * 1.1)
9056

9157

9258
def on_worker_shutdown(sender, **_kwargs) -> None:
93-
with log_context(_logger, logging.INFO, "Worker Shuts-down"):
59+
with log_context(_logger, logging.INFO, "Worker shutdown"):
9460
assert isinstance(sender.app, Celery)
95-
fastapi_app = get_fastapi_app(sender.app)
96-
assert isinstance(fastapi_app, FastAPI)
97-
fastapi_app.state.shutdown_event.set()
61+
app_server = get_app_server(sender.app)
62+
63+
app_server.event_loop.run_until_complete(app_server.shutdown())

packages/celery-library/src/celery_library/task.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,9 @@
1515
from pydantic import NonNegativeInt
1616
from servicelib.async_utils import cancel_wait_task
1717

18-
from . import get_event_loop
1918
from .errors import encode_celery_transferrable_error
2019
from .models import TaskID
21-
from .utils import get_fastapi_app
20+
from .utils import get_app_server
2221

2322
_logger = logging.getLogger(__name__)
2423

@@ -48,7 +47,7 @@ def decorator(
4847
) -> Callable[Concatenate[AbortableTask, P], R]:
4948
@wraps(coro)
5049
def wrapper(task: AbortableTask, *args: P.args, **kwargs: P.kwargs) -> R:
51-
fastapi_app = get_fastapi_app(app)
50+
app_server = get_app_server(app)
5251
# NOTE: task.request is a thread local object, so we need to pass the id explicitly
5352
assert task.request.id is not None # nosec
5453

@@ -90,7 +89,7 @@ async def abort_monitor():
9089

9190
return asyncio.run_coroutine_threadsafe(
9291
run_task(task.request.id),
93-
get_event_loop(fastapi_app),
92+
app_server.event_loop,
9493
).result()
9594

9695
return wrapper
Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
from celery import Celery # type: ignore[import-untyped]
2-
from fastapi import FastAPI
2+
from servicelib.base_app_server import BaseAppServer
33

44
from .task_manager import CeleryTaskManager
55

6-
_WORKER_KEY = "celery_worker"
7-
_FASTAPI_APP_KEY = "fastapi_app"
6+
_APP_SERVER_KEY = "app_server"
7+
_TASK_MANAGER_KEY = "task_manager"
88

99

10-
def set_task_manager(celery_app: Celery, worker: CeleryTaskManager) -> None:
11-
celery_app.conf[_WORKER_KEY] = worker
10+
def get_app_server(app: Celery) -> BaseAppServer:
11+
app_server = app.conf[_APP_SERVER_KEY]
12+
assert isinstance(app_server, BaseAppServer)
13+
return app_server
14+
15+
16+
def set_app_server(app: Celery, app_server: BaseAppServer) -> None:
17+
app.conf[_APP_SERVER_KEY] = app_server
1218

1319

1420
def get_task_manager(celery_app: Celery) -> CeleryTaskManager:
15-
worker = celery_app.conf[_WORKER_KEY]
21+
worker = celery_app.conf[_TASK_MANAGER_KEY]
1622
assert isinstance(worker, CeleryTaskManager)
1723
return worker
1824

1925

20-
def set_fastapi_app(celery_app: Celery, fastapi_app: FastAPI) -> None:
21-
celery_app.conf[_FASTAPI_APP_KEY] = fastapi_app
22-
23-
24-
def get_fastapi_app(celery_app: Celery) -> FastAPI:
25-
fastapi_app = celery_app.conf[_FASTAPI_APP_KEY]
26-
assert isinstance(fastapi_app, FastAPI)
27-
return fastapi_app
26+
def set_task_manager(celery_app: Celery, worker: CeleryTaskManager) -> None:
27+
celery_app.conf[_TASK_MANAGER_KEY] = worker
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from abc import ABC, abstractmethod
2+
from asyncio import AbstractEventLoop
3+
from contextlib import suppress
4+
from typing import TYPE_CHECKING
5+
6+
if TYPE_CHECKING:
7+
with suppress(ImportError):
8+
from fastapi import FastAPI
9+
with suppress(ImportError):
10+
from aiohttp.web import Application
11+
12+
13+
class BaseAppServer(ABC):
14+
@property
15+
def fastapi_app(self) -> "FastAPI":
16+
raise NotImplementedError
17+
18+
@property
19+
def aiohttp_app(self) -> "Application":
20+
raise NotImplementedError
21+
22+
@abstractmethod
23+
async def startup(self):
24+
pass
25+
26+
@property
27+
def event_loop(self) -> AbstractEventLoop:
28+
return self._event_loop
29+
30+
@event_loop.setter
31+
def event_loop(self, loop: AbstractEventLoop) -> None:
32+
self._event_loop = loop
33+
34+
@abstractmethod
35+
async def shutdown(self):
36+
pass
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import asyncio
2+
from datetime import timedelta
3+
from typing import Final
4+
5+
from asgi_lifespan import LifespanManager
6+
from fastapi import FastAPI
7+
8+
from ..base_app_server import BaseAppServer
9+
10+
_SHUTDOWN_TIMEOUT: Final[float] = timedelta(seconds=10).total_seconds()
11+
_STARTUP_TIMEOUT: Final[float] = timedelta(minutes=1).total_seconds()
12+
13+
14+
class FastAPIAppServer(BaseAppServer):
15+
def __init__(self, app: FastAPI):
16+
self._app = app
17+
self._lifespan_manager = None
18+
self._shutdown_event = asyncio.Event()
19+
20+
@property
21+
def fastapi_app(self) -> FastAPI:
22+
assert isinstance(self._app, FastAPI) # nosec
23+
return self._app
24+
25+
async def startup(self):
26+
self._lifespan_manager = LifespanManager(
27+
self.fastapi_app,
28+
startup_timeout=_STARTUP_TIMEOUT,
29+
shutdown_timeout=_SHUTDOWN_TIMEOUT,
30+
)
31+
await self._lifespan_manager.__aenter__()
32+
await self._shutdown_event.wait()
33+
34+
async def shutdown(self):
35+
self._shutdown_event.set()
36+
if self._lifespan_manager is None:
37+
return
38+
await self._lifespan_manager.__aexit__(None, None, None)

services/storage/src/simcore_service_storage/api/_worker_tasks/_files.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from celery import Task # type: ignore[import-untyped]
44
from celery_library.models import TaskID
5-
from celery_library.utils import get_fastapi_app
5+
from celery_library.utils import get_app_server
66
from models_library.api_schemas_storage.storage_schemas import (
77
FileUploadCompletionBody,
88
)
@@ -30,7 +30,7 @@ async def complete_upload_file(
3030
logging.INFO,
3131
msg=f"completing upload of file {user_id=}, {location_id=}, {file_id=}",
3232
):
33-
dsm = get_dsm_provider(get_fastapi_app(task.app)).get(location_id)
33+
dsm = get_dsm_provider(get_app_server(task.app).fastapi_app).get(location_id)
3434
# NOTE: completing a multipart upload on AWS can take up to several minutes
3535
# if it returns slow we return a 202 - Accepted, the client will have to check later
3636
# for completeness

services/storage/src/simcore_service_storage/api/_worker_tasks/_paths.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from celery import Task # type: ignore[import-untyped]
55
from celery_library.models import TaskID
6-
from celery_library.utils import get_fastapi_app
6+
from celery_library.utils import get_app_server
77
from models_library.projects_nodes_io import LocationID, StorageFileID
88
from models_library.users import UserID
99
from pydantic import ByteSize, TypeAdapter
@@ -25,7 +25,7 @@ async def compute_path_size(
2525
logging.INFO,
2626
msg=f"computing path size {user_id=}, {location_id=}, {path=}",
2727
):
28-
dsm = get_dsm_provider(get_fastapi_app(task.app)).get(location_id)
28+
dsm = get_dsm_provider(get_app_server(task.app).fastapi_app).get(location_id)
2929
return await dsm.compute_path_size(user_id, path=Path(path))
3030

3131

@@ -42,7 +42,7 @@ async def delete_paths(
4242
logging.INFO,
4343
msg=f"delete {paths=} in {location_id=} for {user_id=}",
4444
):
45-
dsm = get_dsm_provider(get_fastapi_app(task.app)).get(location_id)
45+
dsm = get_dsm_provider(get_app_server(task.app).fastapi_app).get(location_id)
4646
files_ids: set[StorageFileID] = {
4747
TypeAdapter(StorageFileID).validate_python(f"{path}") for path in paths
4848
}

services/storage/src/simcore_service_storage/api/_worker_tasks/_simcore_s3.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from aws_library.s3._models import S3ObjectKey
66
from celery import Task # type: ignore[import-untyped]
77
from celery_library.models import TaskID
8-
from celery_library.utils import get_fastapi_app, get_task_manager
8+
from celery_library.utils import get_app_server, get_task_manager
99
from models_library.api_schemas_storage.storage_schemas import FoldersBody
1010
from models_library.api_schemas_webserver.storage import PathToExport
1111
from models_library.progress_bar import ProgressReport
@@ -40,7 +40,7 @@ async def deep_copy_files_from_project(
4040
logging.INFO,
4141
msg=f"copying {body.source['uuid']} -> {body.destination['uuid']} with {task.request.id}",
4242
):
43-
dsm = get_dsm_provider(get_fastapi_app(task.app)).get(
43+
dsm = get_dsm_provider(get_app_server(task.app).fastapi_app).get(
4444
SimcoreS3DataManager.get_location_id()
4545
)
4646
assert isinstance(dsm, SimcoreS3DataManager) # nosec
@@ -75,7 +75,7 @@ async def export_data(
7575
logging.INFO,
7676
f"'{task_id}' export data (for {user_id=}) fom selection: {paths_to_export}",
7777
):
78-
dsm = get_dsm_provider(get_fastapi_app(task.app)).get(
78+
dsm = get_dsm_provider(get_app_server(task.app).fastapi_app).get(
7979
SimcoreS3DataManager.get_location_id()
8080
)
8181
assert isinstance(dsm, SimcoreS3DataManager) # nosec

services/storage/src/simcore_service_storage/modules/celery/worker_main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
on_worker_init,
1010
on_worker_shutdown,
1111
)
12+
from servicelib.fastapi.app_server import FastAPIAppServer
1213
from servicelib.logging_utils import config_all_loggers
1314
from simcore_service_storage.api._worker_tasks.tasks import setup_worker_tasks
1415

@@ -28,11 +29,12 @@
2829

2930
assert _settings.STORAGE_CELERY # nosec
3031
app = create_celery_app(_settings.STORAGE_CELERY)
31-
app_factory = partial(create_app, _settings)
32+
33+
app_server = FastAPIAppServer(app=create_app(_settings))
3234

3335

3436
def worker_init_wrapper(sender, **_kwargs):
35-
return partial(on_worker_init, app_factory, _settings.STORAGE_CELERY)(
37+
return partial(on_worker_init, app_server, _settings.STORAGE_CELERY)(
3638
sender, **_kwargs
3739
)
3840

0 commit comments

Comments
 (0)