11from typing import Annotated
22
33from app .core .security import SECRET_KEY , ALGORITHM , oauth2_scheme
4- from app .core .config import RedisRateLimiterSettings , settings
4+ from app .core .config import settings
55
66from sqlalchemy .ext .asyncio import AsyncSession
77from jose import JWTError , jwt
88from fastapi import (
99 Depends ,
1010 HTTPException ,
11- Request ,
12- status
11+ Request
1312)
1413
1514from app .core .database import async_get_db
1615from app .core .models import TokenData
17- # from app.core.rate_limit import is_rate_limited
16+ from app .core .rate_limit import is_rate_limited
17+ from app .core .logger import logging
1818from app .models .user import User
1919from app .api .exceptions import credentials_exception , privileges_exception
2020from app .crud .crud_users import crud_users
2121from app .crud .crud_tier import crud_tiers
22+ from app .crud .crud_rate_limit import crud_rate_limits
23+ from app .schemas .rate_limit import sanitize_path
2224
23- async def get_current_user (token : Annotated [str , Depends (oauth2_scheme )], db : Annotated [AsyncSession , Depends (async_get_db )]) -> User :
25+
26+ logger = logging .getLogger (__name__ )
27+
28+ DEFAULT_LIMIT = settings .DEFAULT_RATE_LIMIT_LIMIT
29+ DEFAULT_PERIOD = settings .DEFAULT_RATE_LIMIT_PERIOD
30+
31+ async def get_current_user (
32+ token : Annotated [str , Depends (oauth2_scheme )],
33+ db : Annotated [AsyncSession , Depends (async_get_db )]
34+ ) -> User | None :
2435 try :
2536 payload = jwt .decode (token , SECRET_KEY , algorithms = [ALGORITHM ])
2637 username_or_email : str = payload .get ("sub" )
2738 if username_or_email is None :
2839 raise credentials_exception
2940 token_data = TokenData (username_or_email = username_or_email )
41+
3042 except JWTError :
3143 raise credentials_exception
3244
@@ -35,16 +47,80 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: An
3547 else :
3648 user = await crud_users .get (db = db , username = token_data .username_or_email )
3749
38- if user is None :
39- raise credentials_exception
50+ if user and not user [ "is_deleted" ] :
51+ return user
4052
41- if user .is_deleted :
42- raise HTTPException (status_code = 400 , detail = "User deleted" )
53+ raise credentials_exception
54+
55+
56+ async def get_optional_user (
57+ request : Request ,
58+ db : AsyncSession = Depends (async_get_db )
59+ ) -> User | None :
60+ token = request .headers .get ("Authorization" )
61+ if not token :
62+ return None
63+
64+ try :
65+ token_type , _ , token_value = token .partition (' ' )
66+ if token_type .lower () != 'bearer' or not token_value :
67+ return None
68+
69+ return await get_current_user (token_value , db )
70+
71+ except HTTPException as http_exc :
72+ if http_exc .status_code != 401 :
73+ logger .error (f"Unexpected HTTPException in get_optional_user: { http_exc .detail } " )
74+ return None
75+
76+ except Exception as exc :
77+ logger .error (f"Unexpected error in get_optional_user: { exc } " )
78+ return None
4379
44- return user
4580
4681async def get_current_superuser (current_user : Annotated [User , Depends (get_current_user )]) -> User :
47- if not current_user . is_superuser :
82+ if not current_user [ " is_superuser" ] :
4883 raise privileges_exception
4984
5085 return current_user
86+
87+
88+ async def rate_limiter (
89+ request : Request ,
90+ db : Annotated [AsyncSession , Depends (async_get_db )],
91+ user : User | None = Depends (get_optional_user )
92+ ):
93+ path = sanitize_path (request .url .path )
94+ if user :
95+ user_id = user ["id" ]
96+ tier = await crud_tiers .get (db , id = user ["tier_id" ])
97+ if tier :
98+ rate_limit = await crud_rate_limits .get (
99+ db = db ,
100+ tier_id = tier ["id" ],
101+ path = path
102+ )
103+ if rate_limit :
104+ limit , period = rate_limit ["limit" ], rate_limit ["period" ]
105+ else :
106+ logger .warning (f"User { user_id } with tier '{ tier ['name' ]} ' has no specific rate limit for path '{ path } '. Applying default rate limit." )
107+ limit , period = DEFAULT_LIMIT , DEFAULT_PERIOD
108+ else :
109+ logger .warning (f"User { user_id } has no assigned tier. Applying default rate limit." )
110+ limit , period = DEFAULT_LIMIT , DEFAULT_PERIOD
111+ else :
112+ user_id = request .client .host
113+ limit , period = DEFAULT_LIMIT , DEFAULT_PERIOD
114+
115+ is_limited = await is_rate_limited (
116+ db = db ,
117+ user_id = user_id ,
118+ path = path ,
119+ limit = limit ,
120+ period = period
121+ )
122+ if is_limited :
123+ raise HTTPException (
124+ status_code = 429 ,
125+ detail = "Rate limit exceeded"
126+ )
0 commit comments