55from collections .abc import AsyncGenerator , Awaitable , Callable
66from contextlib import asynccontextmanager
77from pathlib import Path
8- from typing import TYPE_CHECKING , Literal
8+ from typing import TYPE_CHECKING , cast
99
1010import alembic .command as alembic_command
1111import alembic .config as alembic_config
3333from app .core .utils .config import Settings
3434from app .core .utils .log import LogConfig
3535from app .dependencies import (
36+ disconnect_state ,
37+ get_app_state ,
3638 get_db ,
3739 get_redis_client ,
38- get_scheduler ,
39- get_websocket_connection_manager ,
40- init_and_get_db_engine ,
41- init_websocket_connection_manager ,
40+ init_app_state ,
4241)
4342from app .module import module_list
4443from app .types .exceptions import (
4948from app .types .sqlalchemy import Base
5049from app .utils import initialization
5150from app .utils .redis import limiter
51+ from app .utils .state import LifespanState
5252
5353if TYPE_CHECKING :
5454 import redis
5555
56- from app .types .scheduler import Scheduler
57- from app .types .websocket import WebsocketConnectionManager
5856
5957# NOTE: We can not get loggers at the top of this file like we do in other files
6058# as the loggers are not yet initialized
@@ -401,18 +399,30 @@ async def init_lifespan(
401399 settings : Settings ,
402400 hyperion_error_logger : logging .Logger ,
403401 drop_db : bool ,
404- ) -> tuple [ Scheduler , WebsocketConnectionManager ] :
402+ ) -> LifespanState :
405403 hyperion_error_logger .info ("Startup: Initializing application" )
406404
407- # Init the Redis client
408- redis_client : redis .Redis | bool | None = app .dependency_overrides .get (
405+ # We get `init_app_state` as a dependency, as tests
406+ # should override it to provide their own state
407+ state = await app .dependency_overrides .get (
408+ init_app_state ,
409+ init_app_state ,
410+ )(
411+ app = app ,
412+ settings = settings ,
413+ hyperion_error_logger = hyperion_error_logger ,
414+ )
415+ state = cast ("LifespanState" , state )
416+
417+ redis_client = app .dependency_overrides .get (
409418 get_redis_client ,
410419 get_redis_client ,
411- )(settings = settings )
420+ )(state = state )
412421
413422 # Initialization steps should only be run once across all workers
414423 # We use Redis locks to ensure that the initialization steps are only run once
415- if initialization .get_number_of_workers () > 1 and not isinstance (
424+ number_of_workers = initialization .get_number_of_workers ()
425+ if number_of_workers > 1 and not isinstance (
416426 redis_client ,
417427 Redis ,
418428 ):
@@ -424,6 +434,7 @@ async def init_lifespan(
424434 init_db ,
425435 "init_db" ,
426436 redis_client ,
437+ number_of_workers ,
427438 hyperion_error_logger ,
428439 unlock_key = "db_initialized" ,
429440 settings = settings ,
@@ -435,6 +446,7 @@ async def init_lifespan(
435446 test_configuration ,
436447 "test_configuration" ,
437448 redis_client ,
449+ number_of_workers ,
438450 hyperion_error_logger ,
439451 settings = settings ,
440452 hyperion_error_logger = hyperion_error_logger ,
@@ -443,35 +455,18 @@ async def init_lifespan(
443455 async for db in app .dependency_overrides .get (
444456 get_db ,
445457 get_db ,
446- )():
458+ )(state = state ):
447459 await initialization .use_lock_for_workers (
448460 init_google_API ,
449461 "init_google_API" ,
450462 redis_client ,
463+ number_of_workers ,
451464 hyperion_error_logger ,
452465 db = db ,
453466 settings = settings ,
454467 )
455468
456- init_websocket_connection_manager (WebsocketConnectionManager (settings = settings ))
457- ws_manager : WebsocketConnectionManager = app .dependency_overrides .get (
458- get_websocket_connection_manager ,
459- get_websocket_connection_manager ,
460- )(settings = settings )
461-
462- arq_scheduler : Scheduler = app .dependency_overrides .get (
463- get_scheduler ,
464- get_scheduler ,
465- )(settings = settings )
466-
467- await ws_manager .connect_broadcaster ()
468- await arq_scheduler .start (
469- redis_host = settings .REDIS_HOST ,
470- redis_port = settings .REDIS_PORT ,
471- redis_password = settings .REDIS_PASSWORD ,
472- _dependency_overrides = app .dependency_overrides ,
473- )
474- return arq_scheduler , ws_manager
469+ return state
475470
476471
477472# We wrap the application in a function to be able to pass the settings and drop_db parameters
@@ -487,19 +482,24 @@ def get_application(settings: Settings, drop_db: bool = False) -> FastAPI:
487482 # Creating a lifespan which will be called when the application starts then shuts down
488483 # https://fastapi.tiangolo.com/advanced/events/
489484 @asynccontextmanager
490- async def lifespan (app : FastAPI ) -> AsyncGenerator :
491- arq_scheduler , ws_manager = await init_lifespan (
485+ async def lifespan (app : FastAPI ) -> AsyncGenerator [ LifespanState , None ] :
486+ state = await init_lifespan (
492487 app = app ,
493488 settings = settings ,
494489 hyperion_error_logger = hyperion_error_logger ,
495490 drop_db = drop_db ,
496491 )
497492
498- yield
493+ yield state
499494
500495 hyperion_error_logger .info ("Shutting down" )
501- await arq_scheduler .close ()
502- await ws_manager .disconnect_broadcaster ()
496+ await app .dependency_overrides .get (
497+ disconnect_state ,
498+ disconnect_state ,
499+ )(
500+ state = state ,
501+ hyperion_error_logger = hyperion_error_logger ,
502+ )
503503
504504 # Initialize app
505505 app = FastAPI (
@@ -521,8 +521,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
521521 calypsso = get_calypsso_app ()
522522 app .mount ("/calypsso" , calypsso , "Calypsso" )
523523
524- # We need to init the database engine to be able to use it in dependencies
525- init_and_get_db_engine (settings )
524+ get_app_state_dependency = app .dependency_overrides .get (
525+ get_app_state ,
526+ get_app_state ,
527+ )
528+ get_redis_client_dependency = app .dependency_overrides .get (
529+ get_redis_client ,
530+ get_redis_client ,
531+ )
526532
527533 @app .middleware ("http" )
528534 async def logging_middleware (
@@ -540,6 +546,7 @@ async def logging_middleware(
540546 # This identifier will allow combining logs associated with the same request
541547 # https://www.starlette.io/requests/#other-state
542548 request_id = str (uuid .uuid4 ())
549+
543550 request .state .request_id = request_id
544551
545552 # This should never happen, but we log it just in case
@@ -553,16 +560,13 @@ async def logging_middleware(
553560 port = request .client .port
554561 client_address = f"{ ip_address } :{ port } "
555562
556- redis_client : redis .Redis | Literal [False ] | None = (
557- app .dependency_overrides .get (
558- get_redis_client ,
559- get_redis_client ,
560- )(settings = settings )
563+ redis_client : redis .Redis | None = get_redis_client_dependency (
564+ state = get_app_state_dependency (request ),
561565 )
562566
563567 # We test the ip address with the redis limiter
564568 process = True
565- if redis_client : # If redis is configured
569+ if redis_client and settings . ENABLE_RATE_LIMITER : # If redis is configured
566570 process , log = limiter (
567571 redis_client ,
568572 ip_address ,
0 commit comments