|
8 | 8 |
|
9 | 9 |
|
10 | 10 | from app.api.dependencies import get_current_superuser |
11 | | -from app.core import cache, queue |
| 11 | +from app.core import cache, queue, rate_limit |
12 | 12 | from app.core.config import settings |
13 | 13 | from app.core.database import Base |
14 | 14 | from app.core.database import async_engine as engine |
|
18 | 18 | AppSettings, |
19 | 19 | ClientSideCacheSettings, |
20 | 20 | RedisQueueSettings, |
| 21 | + RedisRateLimiterSettings, |
21 | 22 | EnvironmentOption, |
22 | 23 | EnvironmentSettings |
23 | 24 | ) |
@@ -49,6 +50,16 @@ async def close_redis_queue_pool(): |
49 | 50 | await queue.pool.aclose() |
50 | 51 |
|
51 | 52 |
|
| 53 | +# -------------- rate limit -------------- |
| 54 | +async def create_redis_rate_limit_pool(): |
| 55 | + rate_limit.pool = redis.ConnectionPool.from_url(settings.REDIS_RATE_LIMIT_URL) |
| 56 | + rate_limit.client = redis.Redis.from_pool(rate_limit.pool) |
| 57 | + |
| 58 | + |
| 59 | +async def close_redis_rate_limit_pool(): |
| 60 | + await rate_limit.client.aclose() |
| 61 | + |
| 62 | + |
52 | 63 | # -------------- application -------------- |
53 | 64 | async def set_threadpool_tokens(number_of_tokens=100): |
54 | 65 | limiter = anyio.to_thread.current_default_thread_limiter() |
@@ -133,6 +144,10 @@ def create_application(router: APIRouter, settings, **kwargs) -> FastAPI: |
133 | 144 | application.add_event_handler("startup", create_redis_queue_pool) |
134 | 145 | application.add_event_handler("shutdown", close_redis_queue_pool) |
135 | 146 |
|
| 147 | + if isinstance(settings, RedisRateLimiterSettings): |
| 148 | + application.add_event_handler("startup", create_redis_rate_limit_pool) |
| 149 | + application.add_event_handler("shutdown", close_redis_rate_limit_pool) |
| 150 | + |
136 | 151 | if isinstance(settings, EnvironmentSettings): |
137 | 152 | if settings.ENVIRONMENT != EnvironmentOption.PRODUCTION: |
138 | 153 | docs_router = APIRouter() |
|
0 commit comments