1717from fastapi .middleware .cors import CORSMiddleware
1818from fastapi .responses import JSONResponse
1919from fastapi .routing import APIRoute
20+ from redis import Redis
2021from sqlalchemy .engine import Connection , Engine
2122from sqlalchemy .exc import IntegrityError
23+ from sqlalchemy .ext .asyncio import AsyncSession
2224from sqlalchemy .orm import Session
2325
2426from app import api
3638 get_scheduler ,
3739 get_websocket_connection_manager ,
3840 init_and_get_db_engine ,
41+ init_websocket_connection_manager ,
42+ )
43+ from app .module import module_list
44+ from app .types .exceptions import (
45+ ContentHTTPException ,
46+ GoogleAPIInvalidCredentialsError ,
47+ MultipleWorkersWithoutRedisInitializationError ,
3948)
40- from app .modules .module_list import module_list
41- from app .types .exceptions import ContentHTTPException , GoogleAPIInvalidCredentialsError
4249from app .types .sqlalchemy import Base
4350from app .utils import initialization
4451from app .utils .redis import limiter
4552
4653if TYPE_CHECKING :
4754 import redis
4855
49- from app .types .scheduler import Scheduler
50- from app .types .websocket import WebsocketConnectionManager
56+ from app .types .scheduler import Scheduler
57+ from app .types .websocket import WebsocketConnectionManager
5158
5259# NOTE: We can not get loggers at the top of this file like we do in other files
5360# as the loggers are not yet initialized
@@ -330,6 +337,10 @@ def init_db(
330337 sync_engine = sync_engine ,
331338 hyperion_error_logger = hyperion_error_logger ,
332339 )
340+
341+ # TODO: we may allow the following steps to be run by other workers
342+ # and may not need to wait for them
343+ # These two steps could use an async database connection
333344 initialize_schools (
334345 sync_engine = sync_engine ,
335346 hyperion_error_logger = hyperion_error_logger ,
@@ -340,15 +351,30 @@ def init_db(
340351 )
341352
342353
343- # We wrap the application in a function to be able to pass the settings and drop_db parameters
344- # The drop_db parameter is used to drop the database tables before creating them again
345- def get_application ( settings : Settings , drop_db : bool = False ) -> FastAPI :
346- # Initialize loggers
347- LogConfig (). initialize_loggers ( settings = settings )
354+ async def init_google_API (
355+ db : AsyncSession ,
356+ settings : Settings ,
357+ ) -> None :
358+ # Init Google API credentials
348359
349- hyperion_access_logger = logging .getLogger ("hyperion.access" )
350- hyperion_security_logger = logging .getLogger ("hyperion.security" )
351- hyperion_error_logger = logging .getLogger ("hyperion.error" )
360+ google_api = GoogleAPI ()
361+
362+ if google_api .is_google_api_configured (settings ):
363+ try :
364+ await google_api .get_credentials (db , settings )
365+
366+ except GoogleAPIInvalidCredentialsError :
367+ # We expect this error to be raised if the credentials were never set before
368+ pass
369+
370+
371+ def test_configuration (
372+ settings : Settings ,
373+ hyperion_error_logger : logging .Logger ,
374+ ) -> None :
375+ """
376+ Test configuration and log warnings if some settings are not configured correctly.
377+ """
352378
353379 # We use warning level so that the message is not sent to matrix again
354380 if not settings .MATRIX_TOKEN :
@@ -369,42 +395,108 @@ def get_application(settings: Settings, drop_db: bool = False) -> FastAPI:
369395 Path ("data/ics/" ).mkdir (parents = True , exist_ok = True )
370396 Path ("data/core/" ).mkdir (parents = True , exist_ok = True )
371397
398+
399+ async def init_lifespan (
400+ app : FastAPI ,
401+ settings : Settings ,
402+ hyperion_error_logger : logging .Logger ,
403+ drop_db : bool ,
404+ ) -> tuple [Scheduler , WebsocketConnectionManager ]:
405+ hyperion_error_logger .info ("Startup: Initializing application" )
406+
407+ # Init the Redis client
408+ redis_client : redis .Redis | bool | None = app .dependency_overrides .get (
409+ get_redis_client ,
410+ get_redis_client ,
411+ )(settings = settings )
412+
413+ # Initialization steps should only be run once across all workers
414+ # We use Redis locks to ensure that the initialization steps are only run once
415+ if initialization .get_number_of_workers () > 1 and not isinstance (
416+ redis_client ,
417+ Redis ,
418+ ):
419+ raise MultipleWorkersWithoutRedisInitializationError
420+
421+ # We need to run the database initialization only once across all the workers
422+ # Other workers have to wait for the db to be initialized
423+ await initialization .use_lock_for_workers (
424+ init_db ,
425+ "init_db" ,
426+ redis_client ,
427+ hyperion_error_logger ,
428+ unlock_key = "db_initialized" ,
429+ settings = settings ,
430+ hyperion_error_logger = hyperion_error_logger ,
431+ drop_db = drop_db ,
432+ )
433+
434+ await initialization .use_lock_for_workers (
435+ test_configuration ,
436+ "test_configuration" ,
437+ redis_client ,
438+ hyperion_error_logger ,
439+ settings = settings ,
440+ hyperion_error_logger = hyperion_error_logger ,
441+ )
442+
443+ async for db in app .dependency_overrides .get (
444+ get_db ,
445+ get_db ,
446+ )():
447+ await initialization .use_lock_for_workers (
448+ init_google_API ,
449+ "init_google_API" ,
450+ redis_client ,
451+ hyperion_error_logger ,
452+ db = db ,
453+ settings = settings ,
454+ )
455+
456+ init_websocket_connection_manager (WebsocketConnectionManager (settings = settings ))
457+ ws_manager : WebsocketConnectionManager = app .dependency_overrides .get (
458+ get_websocket_connection_manager ,
459+ get_websocket_connection_manager ,
460+ )(settings = settings )
461+
462+ arq_scheduler : Scheduler = app .dependency_overrides .get (
463+ get_scheduler ,
464+ get_scheduler ,
465+ )(settings = settings )
466+
467+ await ws_manager .connect_broadcaster ()
468+ await arq_scheduler .start (
469+ redis_host = settings .REDIS_HOST ,
470+ redis_port = settings .REDIS_PORT ,
471+ redis_password = settings .REDIS_PASSWORD ,
472+ _dependency_overrides = app .dependency_overrides ,
473+ )
474+ return arq_scheduler , ws_manager
475+
476+
477+ # We wrap the application in a function to be able to pass the settings and drop_db parameters
478+ # The drop_db parameter is used to drop the database tables before creating them again
479+ def get_application (settings : Settings , drop_db : bool = False ) -> FastAPI :
480+ # Initialize loggers
481+ LogConfig ().initialize_loggers (settings = settings )
482+
483+ hyperion_access_logger = logging .getLogger ("hyperion.access" )
484+ hyperion_security_logger = logging .getLogger ("hyperion.security" )
485+ hyperion_error_logger = logging .getLogger ("hyperion.error" )
486+
372487 # Creating a lifespan which will be called when the application starts then shuts down
373488 # https://fastapi.tiangolo.com/advanced/events/
374489 @asynccontextmanager
375490 async def lifespan (app : FastAPI ) -> AsyncGenerator :
376- # Init Google API credentials
377- google_api = GoogleAPI ()
378- if google_api .is_google_api_configured (settings ):
379- async for db in app .dependency_overrides .get (
380- get_db ,
381- get_db ,
382- )():
383- try :
384- await google_api .get_credentials (db , settings )
385- except GoogleAPIInvalidCredentialsError :
386- # We expect this error to be raised if the credentials were never set before
387- pass
388-
389- ws_manager : WebsocketConnectionManager = app .dependency_overrides .get (
390- get_websocket_connection_manager ,
391- get_websocket_connection_manager ,
392- )(settings = settings )
393-
394- arq_scheduler : Scheduler = app .dependency_overrides .get (
395- get_scheduler ,
396- get_scheduler ,
397- )(settings = settings )
398-
399- await ws_manager .connect_broadcaster ()
400- await arq_scheduler .start (
401- redis_host = settings .REDIS_HOST ,
402- redis_port = settings .REDIS_PORT ,
403- redis_password = settings .REDIS_PASSWORD ,
404- _dependency_overrides = app .dependency_overrides ,
491+ arq_scheduler , ws_manager = await init_lifespan (
492+ app = app ,
493+ settings = settings ,
494+ hyperion_error_logger = hyperion_error_logger ,
495+ drop_db = drop_db ,
405496 )
406497
407498 yield
499+
408500 hyperion_error_logger .info ("Shutting down" )
409501 await arq_scheduler .close ()
410502 await ws_manager .disconnect_broadcaster ()
@@ -429,21 +521,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
429521 calypsso = get_calypsso_app ()
430522 app .mount ("/calypsso" , calypsso , "Calypsso" )
431523
432- if settings .HYPERION_INIT_DB :
433- init_db (
434- settings = settings ,
435- hyperion_error_logger = hyperion_error_logger ,
436- drop_db = drop_db ,
437- )
438- else :
439- hyperion_error_logger .info ("Database initialization skipped" )
440-
441- # Initialize Redis
442- if not app .dependency_overrides .get (get_redis_client , get_redis_client )(
443- settings = settings ,
444- ):
445- hyperion_error_logger .info ("Redis client not configured" )
446-
447524 # We need to init the database engine to be able to use it in dependencies
448525 init_and_get_db_engine (settings )
449526
0 commit comments