Skip to content

Commit bec30ba

Browse files
Acquire a lock before running initialization steps (#768)
### Description

If multiple workers are detected without a configured Redis client, Hyperion will raise an error. Without Redis, the server would not be able to:

- run the initialization steps only once across all workers
- send messages over WebSocket reliably

The logic is extracted from the great #375.

---------

Co-authored-by: Thonyk <[email protected]>
Co-authored-by: Timothée Robert <[email protected]>
1 parent c588fc8 commit bec30ba

File tree

18 files changed

+264
-173
lines changed

18 files changed

+264
-173
lines changed

Dockerfile

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,4 @@ COPY migrations migrations/
1616
COPY assets assets/
1717
COPY app app/
1818

19-
COPY start.sh .
20-
RUN chmod +x start.sh
21-
22-
ENTRYPOINT ["./start.sh"]
19+
# The exec (JSON-array) form of ENTRYPOINT does not invoke a shell, so "${NB_WORKERS:-1}" would be passed to fastapi literally.
# Run through `sh -c` so the variable (default: 1) is expanded, and `exec` so fastapi replaces the shell and receives signals as PID 1.
ENTRYPOINT ["sh", "-c", "exec fastapi run app/main.py --workers ${NB_WORKERS:-1}"]

app/api.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,10 @@
22

33
from fastapi import APIRouter
44

5-
from app.core.core_module_list import core_module_list
6-
from app.modules.module_list import module_list
5+
from app.module import all_modules
76

87
api_router = APIRouter()
98

109

11-
for core_module in core_module_list:
10+
for core_module in all_modules:
1211
api_router.include_router(core_module.router)
13-
14-
for module in module_list:
15-
api_router.include_router(module.router)

app/app.py

Lines changed: 133 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717
from fastapi.middleware.cors import CORSMiddleware
1818
from fastapi.responses import JSONResponse
1919
from fastapi.routing import APIRoute
20+
from redis import Redis
2021
from sqlalchemy.engine import Connection, Engine
2122
from sqlalchemy.exc import IntegrityError
23+
from sqlalchemy.ext.asyncio import AsyncSession
2224
from sqlalchemy.orm import Session
2325

2426
from app import api
@@ -36,18 +38,23 @@
3638
get_scheduler,
3739
get_websocket_connection_manager,
3840
init_and_get_db_engine,
41+
init_websocket_connection_manager,
42+
)
43+
from app.module import module_list
44+
from app.types.exceptions import (
45+
ContentHTTPException,
46+
GoogleAPIInvalidCredentialsError,
47+
MultipleWorkersWithoutRedisInitializationError,
3948
)
40-
from app.modules.module_list import module_list
41-
from app.types.exceptions import ContentHTTPException, GoogleAPIInvalidCredentialsError
4249
from app.types.sqlalchemy import Base
4350
from app.utils import initialization
4451
from app.utils.redis import limiter
4552

4653
if TYPE_CHECKING:
4754
import redis
4855

49-
from app.types.scheduler import Scheduler
50-
from app.types.websocket import WebsocketConnectionManager
56+
from app.types.scheduler import Scheduler
57+
from app.types.websocket import WebsocketConnectionManager
5158

5259
# NOTE: We can not get loggers at the top of this file like we do in other files
5360
# as the loggers are not yet initialized
@@ -330,6 +337,10 @@ def init_db(
330337
sync_engine=sync_engine,
331338
hyperion_error_logger=hyperion_error_logger,
332339
)
340+
341+
# TODO: we may allow the following steps to be run by other workers
342+
# and may not need to wait for them
343+
# These two steps could use an async database connection
333344
initialize_schools(
334345
sync_engine=sync_engine,
335346
hyperion_error_logger=hyperion_error_logger,
@@ -340,15 +351,30 @@ def init_db(
340351
)
341352

342353

343-
# We wrap the application in a function to be able to pass the settings and drop_db parameters
344-
# The drop_db parameter is used to drop the database tables before creating them again
345-
def get_application(settings: Settings, drop_db: bool = False) -> FastAPI:
346-
# Initialize loggers
347-
LogConfig().initialize_loggers(settings=settings)
354+
async def init_google_API(
355+
db: AsyncSession,
356+
settings: Settings,
357+
) -> None:
358+
# Init Google API credentials
348359

349-
hyperion_access_logger = logging.getLogger("hyperion.access")
350-
hyperion_security_logger = logging.getLogger("hyperion.security")
351-
hyperion_error_logger = logging.getLogger("hyperion.error")
360+
google_api = GoogleAPI()
361+
362+
if google_api.is_google_api_configured(settings):
363+
try:
364+
await google_api.get_credentials(db, settings)
365+
366+
except GoogleAPIInvalidCredentialsError:
367+
# We expect this error to be raised if the credentials were never set before
368+
pass
369+
370+
371+
def test_configuration(
372+
settings: Settings,
373+
hyperion_error_logger: logging.Logger,
374+
) -> None:
375+
"""
376+
Test configuration and log warnings if some settings are not configured correctly.
377+
"""
352378

353379
# We use warning level so that the message is not sent to matrix again
354380
if not settings.MATRIX_TOKEN:
@@ -369,42 +395,108 @@ def get_application(settings: Settings, drop_db: bool = False) -> FastAPI:
369395
Path("data/ics/").mkdir(parents=True, exist_ok=True)
370396
Path("data/core/").mkdir(parents=True, exist_ok=True)
371397

398+
399+
async def init_lifespan(
400+
app: FastAPI,
401+
settings: Settings,
402+
hyperion_error_logger: logging.Logger,
403+
drop_db: bool,
404+
) -> tuple[Scheduler, WebsocketConnectionManager]:
405+
hyperion_error_logger.info("Startup: Initializing application")
406+
407+
# Init the Redis client
408+
redis_client: redis.Redis | bool | None = app.dependency_overrides.get(
409+
get_redis_client,
410+
get_redis_client,
411+
)(settings=settings)
412+
413+
# Initialization steps should only be run once across all workers
414+
# We use Redis locks to ensure that the initialization steps are only run once
415+
if initialization.get_number_of_workers() > 1 and not isinstance(
416+
redis_client,
417+
Redis,
418+
):
419+
raise MultipleWorkersWithoutRedisInitializationError
420+
421+
# We need to run the database initialization only once across all the workers
422+
# Other workers have to wait for the db to be initialized
423+
await initialization.use_lock_for_workers(
424+
init_db,
425+
"init_db",
426+
redis_client,
427+
hyperion_error_logger,
428+
unlock_key="db_initialized",
429+
settings=settings,
430+
hyperion_error_logger=hyperion_error_logger,
431+
drop_db=drop_db,
432+
)
433+
434+
await initialization.use_lock_for_workers(
435+
test_configuration,
436+
"test_configuration",
437+
redis_client,
438+
hyperion_error_logger,
439+
settings=settings,
440+
hyperion_error_logger=hyperion_error_logger,
441+
)
442+
443+
async for db in app.dependency_overrides.get(
444+
get_db,
445+
get_db,
446+
)():
447+
await initialization.use_lock_for_workers(
448+
init_google_API,
449+
"init_google_API",
450+
redis_client,
451+
hyperion_error_logger,
452+
db=db,
453+
settings=settings,
454+
)
455+
456+
init_websocket_connection_manager(WebsocketConnectionManager(settings=settings))
457+
ws_manager: WebsocketConnectionManager = app.dependency_overrides.get(
458+
get_websocket_connection_manager,
459+
get_websocket_connection_manager,
460+
)(settings=settings)
461+
462+
arq_scheduler: Scheduler = app.dependency_overrides.get(
463+
get_scheduler,
464+
get_scheduler,
465+
)(settings=settings)
466+
467+
await ws_manager.connect_broadcaster()
468+
await arq_scheduler.start(
469+
redis_host=settings.REDIS_HOST,
470+
redis_port=settings.REDIS_PORT,
471+
redis_password=settings.REDIS_PASSWORD,
472+
_dependency_overrides=app.dependency_overrides,
473+
)
474+
return arq_scheduler, ws_manager
475+
476+
477+
# We wrap the application in a function to be able to pass the settings and drop_db parameters
478+
# The drop_db parameter is used to drop the database tables before creating them again
479+
def get_application(settings: Settings, drop_db: bool = False) -> FastAPI:
480+
# Initialize loggers
481+
LogConfig().initialize_loggers(settings=settings)
482+
483+
hyperion_access_logger = logging.getLogger("hyperion.access")
484+
hyperion_security_logger = logging.getLogger("hyperion.security")
485+
hyperion_error_logger = logging.getLogger("hyperion.error")
486+
372487
# Creating a lifespan which will be called when the application starts then shuts down
373488
# https://fastapi.tiangolo.com/advanced/events/
374489
@asynccontextmanager
375490
async def lifespan(app: FastAPI) -> AsyncGenerator:
376-
# Init Google API credentials
377-
google_api = GoogleAPI()
378-
if google_api.is_google_api_configured(settings):
379-
async for db in app.dependency_overrides.get(
380-
get_db,
381-
get_db,
382-
)():
383-
try:
384-
await google_api.get_credentials(db, settings)
385-
except GoogleAPIInvalidCredentialsError:
386-
# We expect this error to be raised if the credentials were never set before
387-
pass
388-
389-
ws_manager: WebsocketConnectionManager = app.dependency_overrides.get(
390-
get_websocket_connection_manager,
391-
get_websocket_connection_manager,
392-
)(settings=settings)
393-
394-
arq_scheduler: Scheduler = app.dependency_overrides.get(
395-
get_scheduler,
396-
get_scheduler,
397-
)(settings=settings)
398-
399-
await ws_manager.connect_broadcaster()
400-
await arq_scheduler.start(
401-
redis_host=settings.REDIS_HOST,
402-
redis_port=settings.REDIS_PORT,
403-
redis_password=settings.REDIS_PASSWORD,
404-
_dependency_overrides=app.dependency_overrides,
491+
arq_scheduler, ws_manager = await init_lifespan(
492+
app=app,
493+
settings=settings,
494+
hyperion_error_logger=hyperion_error_logger,
495+
drop_db=drop_db,
405496
)
406497

407498
yield
499+
408500
hyperion_error_logger.info("Shutting down")
409501
await arq_scheduler.close()
410502
await ws_manager.disconnect_broadcaster()
@@ -429,21 +521,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
429521
calypsso = get_calypsso_app()
430522
app.mount("/calypsso", calypsso, "Calypsso")
431523

432-
if settings.HYPERION_INIT_DB:
433-
init_db(
434-
settings=settings,
435-
hyperion_error_logger=hyperion_error_logger,
436-
drop_db=drop_db,
437-
)
438-
else:
439-
hyperion_error_logger.info("Database initialization skipped")
440-
441-
# Initialize Redis
442-
if not app.dependency_overrides.get(get_redis_client, get_redis_client)(
443-
settings=settings,
444-
):
445-
hyperion_error_logger.info("Redis client not configured")
446-
447524
# We need to init the database engine to be able to use it in dependencies
448525
init_and_get_db_engine(settings)
449526

app/core/core_endpoints/endpoints_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
is_user,
1717
is_user_in,
1818
)
19-
from app.modules.module_list import module_list
19+
from app.module import module_list
2020
from app.types.module import CoreModule
2121
from app.utils.tools import is_group_id_valid
2222

app/core/core_module_list.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

app/core/payment/endpoints_payment.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
from pydantic import TypeAdapter, ValidationError
1010
from sqlalchemy.ext.asyncio import AsyncSession
1111

12-
from app.core.core_module_list import core_module_list
1312
from app.core.payment import cruds_payment, models_payment, schemas_payment
1413
from app.core.payment.types_payment import (
1514
NotificationResultContent,
1615
)
1716
from app.dependencies import get_db
18-
from app.modules.module_list import module_list
17+
from app.module import all_modules
1918
from app.types.module import CoreModule
2019

2120
router = APIRouter(tags=["Payments"])
@@ -133,7 +132,7 @@ async def webhook(
133132

134133
# If a callback is defined for the module, we want to call it
135134
try:
136-
for module in module_list + core_module_list:
135+
for module in all_modules:
137136
if module.root == checkout.module:
138137
if module.payment_callback is not None:
139138
hyperion_error_logger.info(

app/core/utils/config.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,11 +309,6 @@ def MINIMAL_TITAN_VERSION_CODE(cls) -> str:
309309
# Automatically generated parameters #
310310
######################################
311311

312-
# If Hyperion should initialize the database on startup
313-
# This environment variable is set by our init Python file to tell the workers to avoid initializing the database
314-
# You don't want to set this variable manually
315-
HYPERION_INIT_DB: bool = True
316-
317312
# The following properties can not be instantiated as class variables as them need to be computed using another property from the class,
318313
# which won't be available before the .env file parsing.
319314
# We thus decide to use the decorator `@property` to make these methods usable as properties and not functions: as properties: Settings.RSA_PRIVATE_KEY, Settings.RSA_PUBLIC_KEY and Settings.RSA_PUBLIC_JWK

app/dependencies.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,14 +205,16 @@ def get_scheduler(settings: Settings = Depends(get_settings)) -> Scheduler:
205205
return scheduler
206206

207207

208-
def get_websocket_connection_manager(
209-
settings: Settings = Depends(get_settings),
210-
):
208+
def init_websocket_connection_manager(
209+
wscm: WebsocketConnectionManager,
210+
) -> None:
211211
global websocket_connection_manager
212+
websocket_connection_manager = wscm
212213

213-
if websocket_connection_manager is None:
214-
websocket_connection_manager = WebsocketConnectionManager(settings=settings)
215214

215+
def get_websocket_connection_manager(
216+
settings: Settings = Depends(get_settings),
217+
):
216218
return websocket_connection_manager
217219

218220

app/module.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,13 @@
22
import logging
33
from pathlib import Path
44

5-
from app.types.module import Module
5+
from app.types.module import CoreModule, Module
66

77
hyperion_error_logger = logging.getLogger("hyperion.error")
88

99
module_list: list[Module] = []
10+
core_module_list: list[CoreModule] = []
11+
all_modules: list[CoreModule] = []
1012

1113
for endpoints_file in Path().glob("app/modules/*/endpoints_*.py"):
1214
endpoint_module = importlib.import_module(
@@ -19,3 +21,19 @@
1921
hyperion_error_logger.error(
2022
f"Module {endpoints_file} does not declare a module. It won't be enabled.",
2123
)
24+
25+
for endpoints_file in Path().glob("app/core/*/endpoints_*.py"):
26+
endpoint_module = importlib.import_module(
27+
".".join(endpoints_file.with_suffix("").parts),
28+
)
29+
30+
if hasattr(endpoint_module, "core_module"):
31+
core_module: CoreModule = endpoint_module.core_module
32+
core_module_list.append(core_module)
33+
else:
34+
hyperion_error_logger.error(
35+
f"Core module {endpoints_file} does not declare a core module. It won't be enabled.",
36+
)
37+
38+
39+
all_modules = module_list + core_module_list

0 commit comments

Comments
 (0)