Skip to content

Commit b16033f

Browse files
Fix(state management): store state in global variable to allow the sc… (#811)
…heduler to access db (backport of ProximApp#8) * Fix user batch invitation response model * Fix: add missing param in send_emails_from_queue_task * Get db directly using SessionLocal * Store state in global Python variable * Use arq 0.26.3 * Don't keep arq job results after completion to be able to queue new jobs with the same id * Cancel planned notification with the same job_id before queuing a new one * fixup state * Access GLOBAL_STATE in tests init * Remove unexpected state param while disconnecting * Lint * Parametrize test_factory fixture * Lint * Refactor test settings
1 parent 25f3217 commit b16033f

File tree

16 files changed

+212
-183
lines changed

16 files changed

+212
-183
lines changed

app/app.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,10 @@
3636
from app.core.utils.log import LogConfig
3737
from app.dependencies import (
3838
disconnect_state,
39-
get_app_state,
4039
get_db,
4140
get_notification_manager,
4241
get_redis_client,
43-
init_app_state,
42+
init_state,
4443
)
4544
from app.module import all_modules, module_list
4645
from app.types.exceptions import (
@@ -494,11 +493,11 @@ async def init_lifespan(
494493
) -> LifespanState:
495494
hyperion_error_logger.info("Startup: Initializing application")
496495

497-
# We get `init_app_state` as a dependency, as tests
496+
# We get `init_state` as a dependency, as tests
498497
# should override it to provide their own state
499-
state: LifespanState = await app.dependency_overrides.get(
500-
init_app_state,
501-
init_app_state,
498+
await app.dependency_overrides.get(
499+
init_state,
500+
init_state,
502501
)(
503502
app=app,
504503
settings=settings,
@@ -508,7 +507,7 @@ async def init_lifespan(
508507
redis_client: Redis | None = app.dependency_overrides.get(
509508
get_redis_client,
510509
get_redis_client,
511-
)(state=state)
510+
)()
512511

513512
# Initialization steps should only be run once across all workers
514513
# We use Redis locks to ensure that the initialization steps are only run once
@@ -544,14 +543,14 @@ async def init_lifespan(
544543
)
545544

546545
get_db_dependency: Callable[
547-
[LifespanState],
546+
[],
548547
AsyncGenerator[AsyncSession, None],
549548
] = app.dependency_overrides.get(
550549
get_db,
551550
get_db,
552551
)
553552
# We need to run the factories only once across all the workers
554-
async for db in get_db_dependency(state):
553+
async for db in get_db_dependency():
555554
await initialization.use_lock_for_workers(
556555
run_factories,
557556
"run_factories",
@@ -562,7 +561,7 @@ async def init_lifespan(
562561
settings=settings,
563562
hyperion_error_logger=hyperion_error_logger,
564563
)
565-
async for db in get_db_dependency(state):
564+
async for db in get_db_dependency():
566565
await initialization.use_lock_for_workers(
567566
init_google_API,
568567
"init_google_API",
@@ -573,11 +572,11 @@ async def init_lifespan(
573572
settings=settings,
574573
)
575574

576-
async for db in get_db_dependency(state):
575+
async for db in get_db_dependency():
577576
notification_manager = app.dependency_overrides.get(
578577
get_notification_manager,
579578
get_notification_manager,
580-
)(state)
579+
)()
581580
await initialization.use_lock_for_workers(
582581
initialize_notification_topics,
583582
"initialize_notification_topics",
@@ -589,7 +588,7 @@ async def init_lifespan(
589588
notification_manager=notification_manager,
590589
)
591590

592-
return state
591+
return LifespanState()
593592

594593

595594
# We wrap the application in a function to be able to pass the settings and drop_db parameters
@@ -620,7 +619,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[LifespanState, None]:
620619
disconnect_state,
621620
disconnect_state,
622621
)(
623-
state=state,
624622
hyperion_error_logger=hyperion_error_logger,
625623
)
626624

@@ -644,10 +642,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[LifespanState, None]:
644642
calypsso = get_calypsso_app()
645643
app.mount("/calypsso", calypsso, "Calypsso")
646644

647-
get_app_state_dependency = app.dependency_overrides.get(
648-
get_app_state,
649-
get_app_state,
650-
)
651645
get_redis_client_dependency = app.dependency_overrides.get(
652646
get_redis_client,
653647
get_redis_client,
@@ -683,9 +677,7 @@ async def logging_middleware(
683677
port = request.client.port
684678
client_address = f"{ip_address}:{port}"
685679

686-
redis_client: redis.Redis | None = get_redis_client_dependency(
687-
state=get_app_state_dependency(request),
688-
)
680+
redis_client: redis.Redis | None = get_redis_client_dependency()
689681

690682
# We test the ip address with the redis limiter
691683
process = True

app/core/notification/endpoints_notification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ async def send_test_future_notification(
299299
message=message,
300300
defer_date=datetime.now(UTC) + timedelta(seconds=10),
301301
scheduler=scheduler,
302-
job_id="testtt",
302+
job_id="send_test_future_notification",
303303
)
304304

305305

@@ -350,7 +350,7 @@ async def send_test_future_notification_topic(
350350
topic_id=notification_test_topic.id,
351351
message=message,
352352
defer_date=datetime.now(UTC) + timedelta(seconds=10),
353-
job_id="test26",
353+
job_id="notification_test_future",
354354
scheduler=scheduler,
355355
)
356356

app/dependencies.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ async def get_users(db: AsyncSession = Depends(get_db)):
3939
from app.utils.auth import auth_utils
4040
from app.utils.communication.notifications import NotificationManager, NotificationTool
4141
from app.utils.state import (
42-
LifespanState,
42+
GlobalState,
4343
RuntimeLifespanState,
4444
disconnect_redis_client,
4545
disconnect_scheduler,
@@ -61,28 +61,31 @@ async def get_users(db: AsyncSession = Depends(get_db)):
6161
hyperion_access_logger = logging.getLogger("hyperion.access")
6262
hyperion_error_logger = logging.getLogger("hyperion.error")
6363

64+
GLOBAL_STATE: GlobalState
6465

65-
async def init_app_state(
66+
67+
async def init_state(
6668
app: FastAPI,
6769
settings: Settings,
6870
hyperion_error_logger: logging.Logger,
69-
) -> LifespanState:
71+
) -> None:
7072
"""
71-
Initialize the state of the application. This dependency should be used at the start of the application lifespan.
73+
Initialize the global state for the project. This dependency should be called at the start of the application lifespan.
7274
7375
This methode should be called as a dependency, and test may override it to provide their own state.
7476
```python
75-
state = app.dependency_overrides.get(
76-
init_app_state,
77-
init_app_state,
77+
app.dependency_overrides.get(
78+
init_state,
79+
init_state,
7880
)(
7981
app=app,
8082
settings=settings,
8183
hyperion_error_logger=hyperion_error_logger,
8284
)
83-
state = cast("LifespanState", state)
8485
```
8586
"""
87+
global GLOBAL_STATE
88+
8689
engine = init_engine(settings=settings)
8790

8891
SessionLocal = init_SessionLocal(engine)
@@ -94,7 +97,7 @@ async def init_app_state(
9497

9598
scheduler = await init_scheduler(
9699
settings=settings,
97-
app=app,
100+
_dependency_overrides=app.dependency_overrides,
98101
)
99102

100103
ws_manager = await init_websocket_connection_manager(
@@ -112,7 +115,7 @@ async def init_app_state(
112115

113116
mail_templates = init_mail_templates(settings=settings)
114117

115-
return LifespanState(
118+
GLOBAL_STATE = GlobalState(
116119
engine=engine,
117120
SessionLocal=SessionLocal,
118121
redis_client=redis_client,
@@ -126,17 +129,17 @@ async def init_app_state(
126129

127130

128131
async def disconnect_state(
129-
state: LifespanState,
130132
hyperion_error_logger: logging.Logger,
131133
) -> None:
132134
"""
133135
Disconnect items requiring it. This dependency should be used at the end of the application lifespan.
134136
135-
This methode should be called as a dependency as test may need to run additional steps
137+
This methode should be called as a dependency as tests may need to run additional steps
136138
"""
137-
disconnect_redis_client(state["redis_client"])
138-
await disconnect_scheduler(state["scheduler"])
139-
await disconnect_websocket_connection_manager(state["ws_manager"])
139+
140+
disconnect_redis_client(GLOBAL_STATE["redis_client"])
141+
await disconnect_scheduler(GLOBAL_STATE["scheduler"])
142+
await disconnect_websocket_connection_manager(GLOBAL_STATE["ws_manager"])
140143

141144
hyperion_error_logger.info("Application state disconnected successfully.")
142145

@@ -179,7 +182,7 @@ def get_settings() -> Settings:
179182
return construct_prod_settings()
180183

181184

182-
async def get_db(state: AppState) -> AsyncGenerator[AsyncSession, None]:
185+
async def get_db() -> AsyncGenerator[AsyncSession, None]:
183186
"""
184187
Return a database session that will be automatically committed and closed after usage.
185188
@@ -202,7 +205,7 @@ async def get_db(state: AppState) -> AsyncGenerator[AsyncSession, None]:
202205
# Add objects that may be rolled back in case of an error here
203206
```
204207
"""
205-
async with state["SessionLocal"]() as db:
208+
async with GLOBAL_STATE["SessionLocal"]() as db:
206209
try:
207210
yield db
208211
except HTTPException:
@@ -217,42 +220,42 @@ async def get_db(state: AppState) -> AsyncGenerator[AsyncSession, None]:
217220
await db.close()
218221

219222

220-
async def get_unsafe_db(state: AppState) -> AsyncGenerator[AsyncSession, None]:
223+
async def get_unsafe_db() -> AsyncGenerator[AsyncSession, None]:
221224
"""
222225
Return a database session but don't close it automatically
223226
224227
It should only be used for really specific cases where `get_db` will not work
225228
"""
226229

227-
async with state["SessionLocal"]() as db:
230+
async with GLOBAL_STATE["SessionLocal"]() as db:
228231
yield db
229232

230233

231-
def get_redis_client(state: AppState) -> redis.Redis | None:
234+
def get_redis_client() -> redis.Redis | None:
232235
"""
233236
Dependency that returns the redis client
234237
235238
If the redis client is not available, it will return None.
236239
"""
237-
return state["redis_client"]
240+
return GLOBAL_STATE["redis_client"]
238241

239242

240-
def get_scheduler(state: AppState) -> Scheduler:
241-
return state["scheduler"]
243+
def get_scheduler() -> Scheduler:
244+
return GLOBAL_STATE["scheduler"]
242245

243246

244-
def get_websocket_connection_manager(state: AppState) -> WebsocketConnectionManager:
245-
return state["ws_manager"]
247+
def get_websocket_connection_manager() -> WebsocketConnectionManager:
248+
return GLOBAL_STATE["ws_manager"]
246249

247250

248-
def get_notification_manager(state: AppState) -> NotificationManager:
251+
def get_notification_manager() -> NotificationManager:
249252
"""
250253
Dependency that returns the notification manager.
251254
This dependency provide a low level tool allowing to use notification manager internal methods.
252255
253256
If you want to send a notification, prefer `get_notification_tool` dependency.
254257
"""
255-
return state["notification_manager"]
258+
return GLOBAL_STATE["notification_manager"]
256259

257260

258261
def get_notification_tool(
@@ -271,22 +274,20 @@ def get_notification_tool(
271274
)
272275

273276

274-
def get_drive_file_manager(state: AppState) -> DriveFileManager:
277+
def get_drive_file_manager() -> DriveFileManager:
275278
"""
276279
Dependency that returns the drive file manager.
277280
"""
278281

279-
return state["drive_file_manager"]
282+
return GLOBAL_STATE["drive_file_manager"]
280283

281284

282285
@lru_cache
283286
def get_payment_tool(
284287
name: HelloAssoConfigName,
285-
) -> Callable[[AppState], PaymentTool]:
286-
def get_payment_tool(
287-
state: AppState,
288-
) -> PaymentTool:
289-
payment_tools = state["payment_tools"]
288+
) -> Callable[[], PaymentTool]:
289+
def get_payment_tool() -> PaymentTool:
290+
payment_tools = GLOBAL_STATE["payment_tools"]
290291
if name not in payment_tools:
291292
hyperion_error_logger.warning(
292293
f"HelloAsso API credentials are not set for {name.value}, payment won't be available",
@@ -298,14 +299,12 @@ def get_payment_tool(
298299
return get_payment_tool
299300

300301

301-
def get_mail_templates(
302-
state: AppState,
303-
) -> calypsso.MailTemplates:
302+
def get_mail_templates() -> calypsso.MailTemplates:
304303
"""
305304
Dependency that returns the mail templates manager.
306305
"""
307306

308-
return state["mail_templates"]
307+
return GLOBAL_STATE["mail_templates"]
309308

310309

311310
def get_token_data(

0 commit comments

Comments
 (0)