1+ from collections .abc import AsyncGenerator , Callable
2+ from contextlib import _AsyncGeneratorContextManager , asynccontextmanager
13from typing import Any
24
35import anyio
@@ -68,6 +70,50 @@ async def set_threadpool_tokens(number_of_tokens: int = 100) -> None:
6870 limiter .total_tokens = number_of_tokens
6971
7072
73+ def lifespan_factory (
74+ settings : (
75+ DatabaseSettings
76+ | RedisCacheSettings
77+ | AppSettings
78+ | ClientSideCacheSettings
79+ | RedisQueueSettings
80+ | RedisRateLimiterSettings
81+ | EnvironmentSettings
82+ ),
83+ create_tables_on_start : bool = True ,
84+ ) -> Callable [[FastAPI ], _AsyncGeneratorContextManager [Any ]]:
85+ """Factory to create a lifespan async context manager for a FastAPI app."""
86+
87+ @asynccontextmanager
88+ async def lifespan (app : FastAPI ) -> AsyncGenerator :
89+ await set_threadpool_tokens ()
90+
91+ if isinstance (settings , DatabaseSettings ) and create_tables_on_start :
92+ await create_tables ()
93+
94+ if isinstance (settings , RedisCacheSettings ):
95+ await create_redis_cache_pool ()
96+
97+ if isinstance (settings , RedisQueueSettings ):
98+ await create_redis_queue_pool ()
99+
100+ if isinstance (settings , RedisRateLimiterSettings ):
101+ await create_redis_rate_limit_pool ()
102+
103+ yield
104+
105+ if isinstance (settings , RedisCacheSettings ):
106+ await close_redis_cache_pool ()
107+
108+ if isinstance (settings , RedisQueueSettings ):
109+ await close_redis_queue_pool ()
110+
111+ if isinstance (settings , RedisRateLimiterSettings ):
112+ await close_redis_rate_limit_pool ()
113+
114+ return lifespan
115+
116+
71117# -------------- application --------------
72118def create_application (
73119 router : APIRouter ,
@@ -136,30 +182,13 @@ def create_application(
136182 if isinstance (settings , EnvironmentSettings ):
137183 kwargs .update ({"docs_url" : None , "redoc_url" : None , "openapi_url" : None })
138184
139- application = FastAPI (** kwargs )
140-
141- # --- application created ---
142- application .include_router (router )
143- application .add_event_handler ("startup" , set_threadpool_tokens )
185+ lifespan = lifespan_factory (settings , create_tables_on_start = create_tables_on_start )
144186
145- if isinstance (settings , DatabaseSettings ) and create_tables_on_start :
146- application .add_event_handler ("startup" , create_tables )
147-
148- if isinstance (settings , RedisCacheSettings ):
149- application .add_event_handler ("startup" , create_redis_cache_pool )
150- application .add_event_handler ("shutdown" , close_redis_cache_pool )
187+ application = FastAPI (lifespan = lifespan , ** kwargs )
151188
152189 if isinstance (settings , ClientSideCacheSettings ):
153190 application .add_middleware (ClientCacheMiddleware , max_age = settings .CLIENT_CACHE_MAX_AGE )
154191
155- if isinstance (settings , RedisQueueSettings ):
156- application .add_event_handler ("startup" , create_redis_queue_pool )
157- application .add_event_handler ("shutdown" , close_redis_queue_pool )
158-
159- if isinstance (settings , RedisRateLimiterSettings ):
160- application .add_event_handler ("startup" , create_redis_rate_limit_pool )
161- application .add_event_handler ("shutdown" , close_redis_rate_limit_pool )
162-
163192 if isinstance (settings , EnvironmentSettings ):
164193 if settings .ENVIRONMENT != EnvironmentOption .PRODUCTION :
165194 docs_router = APIRouter ()
@@ -181,4 +210,4 @@ async def openapi() -> dict[str, Any]:
181210
182211 application .include_router (docs_router )
183212
184- return application
213+ return application
0 commit comments