Skip to content

Commit 02d401a

Browse files
fix tests
1 parent 096c3b7 commit 02d401a

File tree

12 files changed

+178
-180
lines changed

12 files changed

+178
-180
lines changed

services/storage/src/simcore_service_storage/api/rpc/_async_jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
)
1616
from servicelib.rabbitmq import RPCRouter
1717

18-
from ...modules.celery._utils import get_celery_client
18+
from ...modules.celery import get_celery_client
1919
from ...modules.celery.models import TaskStatus
2020

2121
router = RPCRouter()

services/storage/src/simcore_service_storage/api/rpc/_data_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from ...datcore_dsm import DatCoreDataManager
1515
from ...dsm import get_dsm_provider
1616
from ...exceptions.errors import FileAccessRightError
17-
from ...modules.celery._utils import get_celery_client
17+
from ...modules.celery import get_celery_client
1818
from ...modules.datcore_adapter.datcore_adapter_exceptions import DatcoreAdapterError
1919
from ...simcore_s3_dsm import SimcoreS3DataManager
2020

services/storage/src/simcore_service_storage/main.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,10 @@
33
import logging
44

55
from servicelib.logging_utils import config_all_loggers
6-
from simcore_service_storage.modules.celery._utils import (
7-
set_celery_app,
8-
set_celery_client,
9-
)
6+
from simcore_service_storage.modules.celery import setup_celery
107

118
from .core.application import create_app
129
from .core.settings import ApplicationSettings
13-
from .modules.celery._common import create_app as create_celery_app
14-
from .modules.celery.client import CeleryTaskQueueClient
1510

1611
_settings = ApplicationSettings.create_from_envs()
1712

@@ -27,9 +22,6 @@
2722
_logger = logging.getLogger(__name__)
2823

2924
fastapi_app = create_app(_settings)
30-
celery_app = create_celery_app(_settings)
31-
32-
set_celery_app(fastapi_app, celery_app)
33-
set_celery_client(fastapi_app, CeleryTaskQueueClient(celery_app))
25+
setup_celery(fastapi_app)
3426

3527
app = fastapi_app
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import logging
2+
from asyncio import AbstractEventLoop
3+
4+
from fastapi import FastAPI
5+
from simcore_service_storage.modules.celery._common import create_app
6+
from simcore_service_storage.modules.celery.client import CeleryTaskQueueClient
7+
8+
from ...core.settings import get_application_settings
9+
10+
_logger = logging.getLogger(__name__)
11+
12+
13+
def setup_celery(app: FastAPI) -> None:
14+
async def on_startup() -> None:
15+
celery_settings = get_application_settings(app).STORAGE_CELERY
16+
assert celery_settings # nosec
17+
celery_app = create_app(celery_settings)
18+
app.state.celery_client = CeleryTaskQueueClient(celery_app)
19+
20+
app.add_event_handler("startup", on_startup)
21+
22+
23+
def get_celery_client(app: FastAPI) -> CeleryTaskQueueClient:
24+
celery_client = app.state.celery_client
25+
assert isinstance(celery_client, CeleryTaskQueueClient)
26+
return celery_client
27+
28+
29+
def get_event_loop(app: FastAPI) -> AbstractEventLoop:
30+
event_loop = app.state.event_loop
31+
assert isinstance(event_loop, AbstractEventLoop)
32+
return event_loop
33+
34+
35+
def set_event_loop(app: FastAPI, event_loop: AbstractEventLoop) -> None:
36+
app.state.event_loop = event_loop

services/storage/src/simcore_service_storage/modules/celery/_common.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
import logging
22

33
from celery import Celery
4+
from settings_library.celery import CelerySettings
45
from settings_library.redis import RedisDatabase
56

6-
from ...core.settings import ApplicationSettings
7-
87
_logger = logging.getLogger(__name__)
98

109

11-
def create_app(settings: ApplicationSettings) -> Celery:
12-
celery_settings = settings.STORAGE_CELERY
10+
def create_app(celery_settings: CelerySettings) -> Celery:
1311
assert celery_settings
1412

1513
app = Celery(

services/storage/src/simcore_service_storage/modules/celery/_utils.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

services/storage/src/simcore_service_storage/modules/celery/client.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@
2222
)
2323
_CELERY_TASK_META_PREFIX: Final[str] = "celery-task-meta-"
2424
_CELERY_TASK_ID_PREFIX: Final[str] = "celery"
25+
_CELERY_STATES_MAPPING: Final[dict[str, TaskState]] = {
26+
"PENDING": TaskState.PENDING,
27+
"STARTED": TaskState.PENDING,
28+
"RUNNING": TaskState.RUNNING,
29+
"SUCCESS": TaskState.SUCCESS,
30+
"ABORTED": TaskState.ABORTED,
31+
"FAILURE": TaskState.FAILURE,
32+
}
2533

2634

2735
def _build_context_prefix(task_context: TaskContext) -> list[str]:
@@ -80,26 +88,29 @@ def _get_progress_report(
8088
) -> ProgressReport:
8189
task_id = _build_task_id(task_context, task_uuid)
8290
result = self._celery_app.AsyncResult(task_id).result
83-
state = self._celery_app.AsyncResult(task_id).state
84-
if result and state == TaskState.RUNNING.value:
91+
state = self._get_state(task_context, task_uuid)
92+
if result and state == TaskState.RUNNING:
8593
with contextlib.suppress(ValidationError):
8694
return ProgressReport.model_validate(result)
8795
if state in (
88-
TaskState.ABORTED.value,
89-
TaskState.FAILURE.value,
90-
TaskState.SUCCESS.value,
96+
TaskState.ABORTED,
97+
TaskState.FAILURE,
98+
TaskState.SUCCESS,
9199
):
92100
return ProgressReport(actual_value=100.0)
93101
return ProgressReport(actual_value=0.0)
94102

103+
def _get_state(self, task_context: TaskContext, task_uuid: TaskUUID) -> TaskState:
104+
task_id = _build_task_id(task_context, task_uuid)
105+
return _CELERY_STATES_MAPPING[self._celery_app.AsyncResult(task_id).state]
106+
95107
@make_async()
96108
def get_task_status(
97109
self, task_context: TaskContext, task_uuid: TaskUUID
98110
) -> TaskStatus:
99-
task_id = _build_task_id(task_context, task_uuid)
100111
return TaskStatus(
101112
task_uuid=task_uuid,
102-
task_state=self._celery_app.AsyncResult(task_id).state,
113+
task_state=self._get_state(task_context, task_uuid),
103114
progress_report=self._get_progress_report(task_context, task_uuid),
104115
)
105116

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from celery import Celery
2+
from fastapi import FastAPI
3+
4+
from .worker import CeleryTaskQueueWorker
5+
6+
_WORKER_KEY = "celery_worker"
7+
_FASTAPI_APP_KEY = "fastapi_app"
8+
9+
10+
def get_celery_worker(celery_app: Celery) -> CeleryTaskQueueWorker:
11+
worker = celery_app.conf[_WORKER_KEY]
12+
assert isinstance(worker, CeleryTaskQueueWorker)
13+
return worker
14+
15+
16+
def get_fastapi_app(celery_app: Celery) -> FastAPI:
17+
fastapi_app = celery_app.conf[_FASTAPI_APP_KEY]
18+
assert isinstance(fastapi_app, FastAPI)
19+
return fastapi_app
20+
21+
22+
def set_celery_worker(celery_app: Celery, worker: CeleryTaskQueueWorker) -> None:
23+
celery_app.conf[_WORKER_KEY] = worker
24+
25+
26+
def set_fastapi_app(celery_app: Celery, fastapi_app: FastAPI) -> None:
27+
celery_app.conf[_FASTAPI_APP_KEY] = fastapi_app

services/storage/src/simcore_service_storage/modules/celery/worker.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from models_library.progress_bar import ProgressReport
77
from servicelib.logging_utils import log_context
88

9-
from .models import TaskID, TaskState
9+
from .models import TaskID
1010

1111
_logger = logging.getLogger(__name__)
1212

@@ -30,6 +30,6 @@ def set_task_progress(
3030
):
3131
self.celery_app.tasks[task_name].update_state(
3232
task_id=task_id,
33-
state=TaskState.RUNNING.value,
33+
state="RUNNING",
3434
meta=report.model_dump(mode="json"),
3535
)

services/storage/src/simcore_service_storage/modules/celery/worker_main.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,25 @@
33
import asyncio
44
import logging
55
import threading
6+
from typing import Final
67

78
from asgi_lifespan import LifespanManager
89
from celery import Celery
910
from celery.signals import worker_init, worker_shutdown
11+
from fastapi import FastAPI
1012
from servicelib.background_task import cancel_wait_task
1113
from servicelib.logging_utils import config_all_loggers
14+
from simcore_service_storage.modules.celery import get_event_loop, set_event_loop
15+
from simcore_service_storage.modules.celery.utils import (
16+
CeleryTaskQueueWorker,
17+
get_fastapi_app,
18+
set_celery_worker,
19+
set_fastapi_app,
20+
)
1221

1322
from ...core.application import create_app
1423
from ...core.settings import ApplicationSettings
15-
from ._utils import create_celery_app_worker
24+
from ._common import create_app as create_celery_app
1625

1726
_settings = ApplicationSettings.create_from_envs()
1827

@@ -26,6 +35,8 @@
2635

2736
_logger = logging.getLogger(__name__)
2837

38+
_LIFESPAN_TIMEOUT: Final[int] = 10
39+
2940

3041
@worker_init.connect
3142
def on_worker_init(sender, **_kwargs):
@@ -39,8 +50,8 @@ def _init_fastapi():
3950
async def lifespan():
4051
async with LifespanManager(
4152
fastapi_app,
42-
startup_timeout=10,
43-
shutdown_timeout=10,
53+
startup_timeout=_LIFESPAN_TIMEOUT,
54+
shutdown_timeout=_LIFESPAN_TIMEOUT,
4455
):
4556
try:
4657
await shutdown_event.wait()
@@ -50,9 +61,10 @@ async def lifespan():
5061
lifespan_task = loop.create_task(lifespan())
5162
fastapi_app.state.lifespan_task = lifespan_task
5263
fastapi_app.state.shutdown_event = shutdown_event
64+
set_event_loop(fastapi_app, loop)
5365

54-
sender.app.conf["fastapi_app"] = fastapi_app
55-
sender.app.conf["loop"] = loop
66+
set_fastapi_app(sender.app, fastapi_app)
67+
set_celery_worker(sender.app, CeleryTaskQueueWorker(sender.app))
5668

5769
loop.run_forever()
5870

@@ -64,15 +76,17 @@ async def lifespan():
6476
def on_worker_shutdown(sender, **_kwargs):
6577
assert isinstance(sender.app, Celery)
6678

67-
loop = sender.app.conf["loop"]
68-
fastapi_app = sender.app.conf["fastapi_app"]
79+
fastapi_app = get_fastapi_app(sender.app)
80+
assert isinstance(fastapi_app, FastAPI)
81+
event_loop = get_event_loop(fastapi_app)
6982

7083
async def shutdown():
7184
fastapi_app.state.shutdown_event.set()
7285

7386
await cancel_wait_task(fastapi_app.state.lifespan_task, max_delay=5)
7487

75-
asyncio.run_coroutine_threadsafe(shutdown(), loop)
88+
asyncio.run_coroutine_threadsafe(shutdown(), event_loop)
7689

7790

78-
app = create_celery_app_worker(_settings)
91+
assert _settings.STORAGE_CELERY
92+
app = create_celery_app(_settings.STORAGE_CELERY)

0 commit comments

Comments
 (0)