Skip to content

Commit 531e3cf

Browse files
committed
rate limiter dependency ready and working. Cache decorator updated
1 parent dfd3811 commit 531e3cf

File tree

16 files changed

+338
-63
lines changed

16 files changed

+338
-63
lines changed

src/app/api/dependencies.py

Lines changed: 87 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,44 @@
11
from typing import Annotated
22

33
from app.core.security import SECRET_KEY, ALGORITHM, oauth2_scheme
4-
from app.core.config import RedisRateLimiterSettings, settings
4+
from app.core.config import settings
55

66
from sqlalchemy.ext.asyncio import AsyncSession
77
from jose import JWTError, jwt
88
from fastapi import (
99
Depends,
1010
HTTPException,
11-
Request,
12-
status
11+
Request
1312
)
1413

1514
from app.core.database import async_get_db
1615
from app.core.models import TokenData
17-
# from app.core.rate_limit import is_rate_limited
16+
from app.core.rate_limit import is_rate_limited
17+
from app.core.logger import logging
1818
from app.models.user import User
1919
from app.api.exceptions import credentials_exception, privileges_exception
2020
from app.crud.crud_users import crud_users
2121
from app.crud.crud_tier import crud_tiers
22+
from app.crud.crud_rate_limit import crud_rate_limits
23+
from app.schemas.rate_limit import sanitize_path
2224

23-
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[AsyncSession, Depends(async_get_db)]) -> User:
25+
26+
logger = logging.getLogger(__name__)
27+
28+
DEFAULT_LIMIT = settings.DEFAULT_RATE_LIMIT_LIMIT
29+
DEFAULT_PERIOD = settings.DEFAULT_RATE_LIMIT_PERIOD
30+
31+
async def get_current_user(
32+
token: Annotated[str, Depends(oauth2_scheme)],
33+
db: Annotated[AsyncSession, Depends(async_get_db)]
34+
) -> User | None:
2435
try:
2536
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
2637
username_or_email: str = payload.get("sub")
2738
if username_or_email is None:
2839
raise credentials_exception
2940
token_data = TokenData(username_or_email=username_or_email)
41+
3042
except JWTError:
3143
raise credentials_exception
3244

@@ -35,16 +47,80 @@ async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: An
3547
else:
3648
user = await crud_users.get(db=db, username=token_data.username_or_email)
3749

38-
if user is None:
39-
raise credentials_exception
50+
if user and not user["is_deleted"]:
51+
return user
4052

41-
if user.is_deleted:
42-
raise HTTPException(status_code=400, detail="User deleted")
53+
raise credentials_exception
54+
55+
56+
async def get_optional_user(
57+
request: Request,
58+
db: AsyncSession = Depends(async_get_db)
59+
) -> User | None:
60+
token = request.headers.get("Authorization")
61+
if not token:
62+
return None
63+
64+
try:
65+
token_type, _, token_value = token.partition(' ')
66+
if token_type.lower() != 'bearer' or not token_value:
67+
return None
68+
69+
return await get_current_user(token_value, db)
70+
71+
except HTTPException as http_exc:
72+
if http_exc.status_code != 401:
73+
logger.error(f"Unexpected HTTPException in get_optional_user: {http_exc.detail}")
74+
return None
75+
76+
except Exception as exc:
77+
logger.error(f"Unexpected error in get_optional_user: {exc}")
78+
return None
4379

44-
return user
4580

4681
async def get_current_superuser(current_user: Annotated[User, Depends(get_current_user)]) -> User:
47-
if not current_user.is_superuser:
82+
if not current_user["is_superuser"]:
4883
raise privileges_exception
4984

5085
return current_user
86+
87+
88+
async def rate_limiter(
89+
request: Request,
90+
db: Annotated[AsyncSession, Depends(async_get_db)],
91+
user: User | None = Depends(get_optional_user)
92+
):
93+
path = sanitize_path(request.url.path)
94+
if user:
95+
user_id = user["id"]
96+
tier = await crud_tiers.get(db, id=user["tier_id"])
97+
if tier:
98+
rate_limit = await crud_rate_limits.get(
99+
db=db,
100+
tier_id=tier["id"],
101+
path=path
102+
)
103+
if rate_limit:
104+
limit, period = rate_limit["limit"], rate_limit["period"]
105+
else:
106+
logger.warning(f"User {user_id} with tier '{tier['name']}' has no specific rate limit for path '{path}'. Applying default rate limit.")
107+
limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD
108+
else:
109+
logger.warning(f"User {user_id} has no assigned tier. Applying default rate limit.")
110+
limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD
111+
else:
112+
user_id = request.client.host
113+
limit, period = DEFAULT_LIMIT, DEFAULT_PERIOD
114+
115+
is_limited = await is_rate_limited(
116+
db=db,
117+
user_id=user_id,
118+
path=path,
119+
limit=limit,
120+
period=period
121+
)
122+
if is_limited:
123+
raise HTTPException(
124+
status_code=429,
125+
detail="Rate limit exceeded"
126+
)

src/app/api/v1/login.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ async def login_for_access_token(
2929

3030
access_token_expires = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
3131
access_token = await create_access_token(
32-
data={"sub": user.username}, expires_delta=access_token_expires
32+
data={"sub": user["username"]}, expires_delta=access_token_expires
3333
)
3434

3535
return {"access_token": access_token, "token_type": "bearer"}

src/app/api/v1/posts.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Annotated
1+
from typing import Annotated
22

33
from fastapi import Request, Depends, HTTPException
44
from sqlalchemy.ext.asyncio import AsyncSession
@@ -28,7 +28,7 @@ async def write_post(
2828
if db_user is None:
2929
raise HTTPException(status_code=404, detail="User not found")
3030

31-
if current_user.id != db_user["id"]:
31+
if current_user["id"] != db_user["id"]:
3232
raise privileges_exception
3333

3434
post_internal_dict = post.model_dump()
@@ -94,7 +94,7 @@ async def read_post(
9494
@cache(
9595
"{username}_post_cache",
9696
resource_id_name="id",
97-
to_invalidate_extra={"{username}_posts": "{username}"}
97+
pattern_to_invalidate_extra=["{username}_posts:*"]
9898
)
9999
async def patch_post(
100100
request: Request,
@@ -108,7 +108,7 @@ async def patch_post(
108108
if db_user is None:
109109
raise HTTPException(status_code=404, detail="User not found")
110110

111-
if current_user.id != db_user["id"]:
111+
if current_user["id"] != db_user["id"]:
112112
raise privileges_exception
113113

114114
db_post = await crud_posts.get(db=db, schema_to_select=PostRead, id=id, is_deleted=False)
@@ -136,7 +136,7 @@ async def erase_post(
136136
if db_user is None:
137137
raise HTTPException(status_code=404, detail="User not found")
138138

139-
if current_user.id != db_user.id:
139+
if current_user["id"] != db_user["id"]:
140140
raise privileges_exception
141141

142142
db_post = await crud_posts.get(db=db, schema_to_select=PostRead, id=id, is_deleted=False)

src/app/api/v1/rate_limits.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,28 @@
44
from sqlalchemy.ext.asyncio import AsyncSession
55
import fastapi
66

7+
from app.api.dependencies import get_current_superuser
8+
from app.api.paginated import PaginatedListResponse, paginated_response, compute_offset
9+
from app.core.database import async_get_db
10+
from app.crud.crud_rate_limit import crud_rate_limits
11+
from app.crud.crud_tier import crud_tiers
712
from app.schemas.rate_limit import (
813
RateLimitRead,
914
RateLimitCreate,
1015
RateLimitCreateInternal,
1116
RateLimitUpdate
1217
)
13-
from app.api.dependencies import get_current_superuser
14-
from app.core.database import async_get_db
15-
from app.crud.crud_rate_limit import crud_rate_limits
16-
from app.crud.crud_tier import crud_tiers
17-
from app.api.paginated import PaginatedListResponse, paginated_response, compute_offset
1818

1919
router = fastapi.APIRouter(tags=["rate_limits"])
2020

21-
@router.post("/tier/{name}/rate_limit", dependencies=[Depends(get_current_superuser)], status_code=201)
21+
@router.post("/tier/{tier_name}/rate_limit", dependencies=[Depends(get_current_superuser)], status_code=201)
2222
async def write_rate_limit(
2323
request: Request,
24-
name: str,
24+
tier_name: str,
2525
rate_limit: RateLimitCreate,
2626
db: Annotated[AsyncSession, Depends(async_get_db)]
2727
):
28-
db_tier = await crud_tiers.get(db=db, name=name)
28+
db_tier = await crud_tiers.get(db=db, name=tier_name)
2929
if not db_tier:
3030
raise HTTPException(status_code=404, detail="Tier not found")
3131

@@ -40,15 +40,15 @@ async def write_rate_limit(
4040
return await crud_rate_limits.create(db=db, object=rate_limit_internal)
4141

4242

43-
@router.get("/tier/{name}/rate_limits", response_model=PaginatedListResponse[RateLimitRead])
43+
@router.get("/tier/{tier_name}/rate_limits", response_model=PaginatedListResponse[RateLimitRead])
4444
async def read_rate_limits(
4545
request: Request,
46-
name: str,
46+
tier_name: str,
4747
db: Annotated[AsyncSession, Depends(async_get_db)],
4848
page: int = 1,
4949
items_per_page: int = 10
5050
):
51-
db_tier = await crud_tiers.get(db=db, name=name)
51+
db_tier = await crud_tiers.get(db=db, name=tier_name)
5252
if not db_tier:
5353
raise HTTPException(status_code=404, detail="Tier not found")
5454

@@ -90,7 +90,7 @@ async def read_rate_limit(
9090
return db_rate_limit
9191

9292

93-
@router.patch("/tier/{tier_name}/rate_limit/{path}", dependencies=[Depends(get_current_superuser)])
93+
@router.patch("/tier/{tier_name}/rate_limit/{id}", dependencies=[Depends(get_current_superuser)])
9494
async def patch_rate_limit(
9595
request: Request,
9696
tier_name: str,
@@ -99,23 +99,31 @@ async def patch_rate_limit(
9999
db: Annotated[AsyncSession, Depends(async_get_db)]
100100
):
101101
db_tier = await crud_tiers.get(db=db, name=tier_name)
102-
if not db_tier:
102+
if db_tier is None:
103103
raise HTTPException(status_code=404, detail="Tier not found")
104104

105105
db_rate_limit = await crud_rate_limits.get(
106-
db=db,
106+
db=db,
107107
schema_to_select=RateLimitRead,
108108
tier_id=db_tier["id"],
109109
id=id
110110
)
111111
if db_rate_limit is None:
112112
raise HTTPException(status_code=404, detail="Rate Limit not found")
113113

114+
db_rate_limit_path = await crud_rate_limits.exists(
115+
db=db,
116+
tier_id=db_tier["id"],
117+
path=values.path
118+
)
119+
120+
db_rate_limit_name = await crud_rate_limits.exists(db=db)
121+
114122
await crud_rate_limits.update(db=db, object=values, id=db_rate_limit["id"])
115123
return {"message": "Rate Limit updated"}
116124

117125

118-
@router.delete("/tier/{tier_name}/rate_limit/{path}", dependencies=[Depends(get_current_superuser)])
126+
@router.delete("/tier/{tier_name}/rate_limit/{id}", dependencies=[Depends(get_current_superuser)])
119127
async def erase_rate_limit(
120128
request: Request,
121129
tier_name: str,

src/app/api/v1/tasks.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from arq.jobs import Job as ArqJob
2-
from fastapi import APIRouter, HTTPException
2+
from fastapi import APIRouter, Depends
33

44
from app.core import queue
55
from app.schemas.job import Job
6+
from app.api.dependencies import rate_limiter
67

78
router = APIRouter(prefix="/tasks", tags=["tasks"])
89

9-
@router.post("/task", response_model=Job, status_code=201)
10+
@router.post("/task", response_model=Job, status_code=201, dependencies=[Depends(rate_limiter)])
1011
async def create_task(message: str):
1112
job = await queue.pool.enqueue_job("sample_background_task", message)
1213
return {"id": job.job_id}

src/app/api/v1/tiers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Annotated
1+
from typing import Annotated
22

33
from fastapi import Request, Depends, HTTPException
44
from sqlalchemy.ext.asyncio import AsyncSession

0 commit comments

Comments
 (0)