Skip to content

Commit efb3d0f

Browse files
Refactor app state using dependencies and state (#769)
### Description

Instead of initializing objects like the Redis client or the database connection the first time a dependency is called, we create a dedicated `init_app_state` dependency that initializes all of these objects. They are stored in the application state instead of in global variables. This ensures that an object cannot be reused by another Hyperion app instance, which would have its own event loop; such reuse previously caused issues in tests.

Requires #768

---------

Co-authored-by: Thonyk <[email protected]>
1 parent bec30ba commit efb3d0f

File tree

16 files changed

+503
-400
lines changed

16 files changed

+503
-400
lines changed

.env.test.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ AUTH_CLIENTS_DICT:
5151
# Logging configuration #
5252

5353
LOG_DEBUG_MESSAGES: true
54+
ENABLE_RATE_LIMITER: false
5455

5556
# CORS_ORIGINS should be a list of urls allowed to make requests to the API
5657
# It should begin with 'http://' or 'https://' and should never end with a '/'

app/app.py

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncGenerator, Awaitable, Callable
66
from contextlib import asynccontextmanager
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Literal
8+
from typing import TYPE_CHECKING, cast
99

1010
import alembic.command as alembic_command
1111
import alembic.config as alembic_config
@@ -33,12 +33,11 @@
3333
from app.core.utils.config import Settings
3434
from app.core.utils.log import LogConfig
3535
from app.dependencies import (
36+
disconnect_state,
37+
get_app_state,
3638
get_db,
3739
get_redis_client,
38-
get_scheduler,
39-
get_websocket_connection_manager,
40-
init_and_get_db_engine,
41-
init_websocket_connection_manager,
40+
init_app_state,
4241
)
4342
from app.module import module_list
4443
from app.types.exceptions import (
@@ -49,12 +48,11 @@
4948
from app.types.sqlalchemy import Base
5049
from app.utils import initialization
5150
from app.utils.redis import limiter
51+
from app.utils.state import LifespanState
5252

5353
if TYPE_CHECKING:
5454
import redis
5555

56-
from app.types.scheduler import Scheduler
57-
from app.types.websocket import WebsocketConnectionManager
5856

5957
# NOTE: We can not get loggers at the top of this file like we do in other files
6058
# as the loggers are not yet initialized
@@ -401,18 +399,30 @@ async def init_lifespan(
401399
settings: Settings,
402400
hyperion_error_logger: logging.Logger,
403401
drop_db: bool,
404-
) -> tuple[Scheduler, WebsocketConnectionManager]:
402+
) -> LifespanState:
405403
hyperion_error_logger.info("Startup: Initializing application")
406404

407-
# Init the Redis client
408-
redis_client: redis.Redis | bool | None = app.dependency_overrides.get(
405+
# We get `init_app_state` as a dependency, as tests
406+
# should override it to provide their own state
407+
state = await app.dependency_overrides.get(
408+
init_app_state,
409+
init_app_state,
410+
)(
411+
app=app,
412+
settings=settings,
413+
hyperion_error_logger=hyperion_error_logger,
414+
)
415+
state = cast("LifespanState", state)
416+
417+
redis_client = app.dependency_overrides.get(
409418
get_redis_client,
410419
get_redis_client,
411-
)(settings=settings)
420+
)(state=state)
412421

413422
# Initialization steps should only be run once across all workers
414423
# We use Redis locks to ensure that the initialization steps are only run once
415-
if initialization.get_number_of_workers() > 1 and not isinstance(
424+
number_of_workers = initialization.get_number_of_workers()
425+
if number_of_workers > 1 and not isinstance(
416426
redis_client,
417427
Redis,
418428
):
@@ -424,6 +434,7 @@ async def init_lifespan(
424434
init_db,
425435
"init_db",
426436
redis_client,
437+
number_of_workers,
427438
hyperion_error_logger,
428439
unlock_key="db_initialized",
429440
settings=settings,
@@ -435,6 +446,7 @@ async def init_lifespan(
435446
test_configuration,
436447
"test_configuration",
437448
redis_client,
449+
number_of_workers,
438450
hyperion_error_logger,
439451
settings=settings,
440452
hyperion_error_logger=hyperion_error_logger,
@@ -443,35 +455,18 @@ async def init_lifespan(
443455
async for db in app.dependency_overrides.get(
444456
get_db,
445457
get_db,
446-
)():
458+
)(state=state):
447459
await initialization.use_lock_for_workers(
448460
init_google_API,
449461
"init_google_API",
450462
redis_client,
463+
number_of_workers,
451464
hyperion_error_logger,
452465
db=db,
453466
settings=settings,
454467
)
455468

456-
init_websocket_connection_manager(WebsocketConnectionManager(settings=settings))
457-
ws_manager: WebsocketConnectionManager = app.dependency_overrides.get(
458-
get_websocket_connection_manager,
459-
get_websocket_connection_manager,
460-
)(settings=settings)
461-
462-
arq_scheduler: Scheduler = app.dependency_overrides.get(
463-
get_scheduler,
464-
get_scheduler,
465-
)(settings=settings)
466-
467-
await ws_manager.connect_broadcaster()
468-
await arq_scheduler.start(
469-
redis_host=settings.REDIS_HOST,
470-
redis_port=settings.REDIS_PORT,
471-
redis_password=settings.REDIS_PASSWORD,
472-
_dependency_overrides=app.dependency_overrides,
473-
)
474-
return arq_scheduler, ws_manager
469+
return state
475470

476471

477472
# We wrap the application in a function to be able to pass the settings and drop_db parameters
@@ -487,19 +482,24 @@ def get_application(settings: Settings, drop_db: bool = False) -> FastAPI:
487482
# Creating a lifespan which will be called when the application starts then shuts down
488483
# https://fastapi.tiangolo.com/advanced/events/
489484
@asynccontextmanager
490-
async def lifespan(app: FastAPI) -> AsyncGenerator:
491-
arq_scheduler, ws_manager = await init_lifespan(
485+
async def lifespan(app: FastAPI) -> AsyncGenerator[LifespanState, None]:
486+
state = await init_lifespan(
492487
app=app,
493488
settings=settings,
494489
hyperion_error_logger=hyperion_error_logger,
495490
drop_db=drop_db,
496491
)
497492

498-
yield
493+
yield state
499494

500495
hyperion_error_logger.info("Shutting down")
501-
await arq_scheduler.close()
502-
await ws_manager.disconnect_broadcaster()
496+
await app.dependency_overrides.get(
497+
disconnect_state,
498+
disconnect_state,
499+
)(
500+
state=state,
501+
hyperion_error_logger=hyperion_error_logger,
502+
)
503503

504504
# Initialize app
505505
app = FastAPI(
@@ -521,8 +521,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
521521
calypsso = get_calypsso_app()
522522
app.mount("/calypsso", calypsso, "Calypsso")
523523

524-
# We need to init the database engine to be able to use it in dependencies
525-
init_and_get_db_engine(settings)
524+
get_app_state_dependency = app.dependency_overrides.get(
525+
get_app_state,
526+
get_app_state,
527+
)
528+
get_redis_client_dependency = app.dependency_overrides.get(
529+
get_redis_client,
530+
get_redis_client,
531+
)
526532

527533
@app.middleware("http")
528534
async def logging_middleware(
@@ -540,6 +546,7 @@ async def logging_middleware(
540546
# This identifier will allow combining logs associated with the same request
541547
# https://www.starlette.io/requests/#other-state
542548
request_id = str(uuid.uuid4())
549+
543550
request.state.request_id = request_id
544551

545552
# This should never happen, but we log it just in case
@@ -553,16 +560,13 @@ async def logging_middleware(
553560
port = request.client.port
554561
client_address = f"{ip_address}:{port}"
555562

556-
redis_client: redis.Redis | Literal[False] | None = (
557-
app.dependency_overrides.get(
558-
get_redis_client,
559-
get_redis_client,
560-
)(settings=settings)
563+
redis_client: redis.Redis | None = get_redis_client_dependency(
564+
state=get_app_state_dependency(request),
561565
)
562566

563567
# We test the ip address with the redis limiter
564568
process = True
565-
if redis_client: # If redis is configured
569+
if redis_client and settings.ENABLE_RATE_LIMITER: # If redis is configured and the rate limiter is enabled
566570
process, log = limiter(
567571
redis_client,
568572
ip_address,

app/core/core_endpoints/endpoints_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from os import path
33
from pathlib import Path
44

5-
from fastapi import APIRouter, Depends, HTTPException
5+
from fastapi import APIRouter, Depends, HTTPException, Request
66
from fastapi.responses import FileResponse
77
from sqlalchemy.ext.asyncio import AsyncSession
88

@@ -37,6 +37,7 @@
3737
status_code=200,
3838
)
3939
async def read_information(
40+
request: Request,
4041
settings: Settings = Depends(get_settings),
4142
):
4243
"""

app/core/utils/config.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,11 +193,15 @@ def settings_customise_sources(
193193
# Redis configuration is needed to use the rate limiter or to run multiple uvicorn workers
194194
# We use the default redis configuration, so the protected mode is enabled by default (see https://redis.io/docs/manual/security/#protected-mode)
195195
# If you want to use a custom configuration, a password and specific bind addresses should be used to avoid security issues
196-
REDIS_HOST: str
197-
REDIS_PORT: int
196+
REDIS_HOST: str | None = None
197+
REDIS_PORT: int = 6379
198198
REDIS_PASSWORD: str | None = None
199-
REDIS_LIMIT: int
200-
REDIS_WINDOW: int
199+
REDIS_LIMIT: int = 1000
200+
REDIS_WINDOW: int = 60
201+
202+
# Rate limit requests based on REDIS_LIMIT and REDIS_WINDOW
203+
# A working Redis client is required to use the rate limiter
204+
ENABLE_RATE_LIMITER: bool = True
201205

202206
##########################
203207
# Firebase Configuration #

0 commit comments

Comments
 (0)