Skip to content

Commit 0194146

Browse files
committed
Fix rate limiting
1 parent 22d6670 commit 0194146

File tree

4 files changed

+36
-48
lines changed

4 files changed

+36
-48
lines changed

backend/routes/account.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from security.audit import log_security
1919
from security.profanity import contains_profanity
2020
from security.user_agent_blocklist import is_user_agent_blocked
21-
from security.rate_limit import rate_limit_per_ip, rate_limit_per_user
21+
from security.rate_limit import rate_limit_per_ip
2222
router = APIRouter()
2323

2424
_FAILED_ATTEMPT_WINDOW_SECONDS = 300
@@ -445,7 +445,7 @@ def logout(
445445

446446

447447
@router.post("/change-password")
448-
@rate_limit_per_user("5/hour")
448+
@rate_limit_per_ip("5/hour")
449449
def change_password(
450450
request: Request,
451451
password_request: ChangePasswordRequest,
@@ -487,7 +487,8 @@ def change_password(
487487

488488

489489
@router.get("/users")
490-
def list_users(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
490+
@rate_limit_per_ip("30/minute") # Per-IP limit to prevent abuse
491+
def list_users(request: Request, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
491492
users = db.query(User).order_by(User.username.asc()).all()
492493
return {
493494
"users": [
@@ -497,13 +498,15 @@ def list_users(current_user: User = Depends(get_current_user), db: Session = Dep
497498

498499

499500
@router.get("/crypto/public-key/of/{user_id}")
500-
def get_public_key_of(user_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
501+
@rate_limit_per_ip("100/minute") # Per-IP limit to prevent abuse
502+
def get_public_key_of(request: Request, user_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
501503
row = db.query(CryptoPublicKey).filter(CryptoPublicKey.user_id == user_id).first()
502504
return {"publicKey": row.public_key_b64 if row else None}
503505

504506

505507
@router.get("/users/search")
506-
def search_users(q: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
508+
@rate_limit_per_ip("60/minute") # Per-IP limit to prevent abuse
509+
def search_users(request: Request, q: str, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
507510
if len(q.strip()) < 2:
508511
return {"users": []}
509512

backend/routes/messaging.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from better_profanity import profanity as _bp
2828
from security.audit import log_access, log_dm, log_public_chat, log_security
2929
from security.profanity import censor_text
30-
from security.rate_limit import rate_limit_per_user
30+
from security.rate_limit import rate_limit_per_ip
3131

3232
router = APIRouter()
3333
logger = logging.getLogger("uvicorn.error")
@@ -386,7 +386,7 @@ async def _send_message_internal(
386386

387387

388388
@router.post("/send_message")
389-
@rate_limit_per_user("30/minute")
389+
@rate_limit_per_ip("30/minute")
390390
async def send_message(
391391
request: Request,
392392
message_request: SendMessageRequest | None = None,
@@ -414,7 +414,8 @@ async def send_message(
414414

415415

416416
@router.get("/get_messages")
417-
async def get_messages(db: Session = Depends(get_db)):
417+
@rate_limit_per_ip("60/minute") # Per-IP limit to prevent abuse
418+
async def get_messages(request: Request, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
418419
messages = db.query(Message).order_by(Message.timestamp.asc()).all()
419420

420421
messages_data = []
@@ -428,7 +429,7 @@ async def get_messages(db: Session = Depends(get_db)):
428429

429430

430431
@router.post("/dm/send")
431-
@rate_limit_per_user("20/minute")
432+
@rate_limit_per_ip("20/minute")
432433
async def dm_send(
433434
request: Request,
434435
payload: dict | None = None,
@@ -578,15 +579,17 @@ def convert_envelopes(envs: list[DMEnvelope]):
578579
}
579580

580581
@router.get("/dm/fetch")
581-
async def dm_fetch(since: int | None = None, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
582+
@rate_limit_per_ip("60/minute") # Per-IP limit to prevent abuse
583+
async def dm_fetch(request: Request, since: int | None = None, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
582584
q = db.query(DMEnvelope).filter(DMEnvelope.recipient_id == current_user.id)
583585
if since:
584586
q = q.filter(DMEnvelope.id > since)
585587
return convert_envelopes(q.order_by(DMEnvelope.id.asc()).all())
586588

587589

588590
@router.get("/dm/history/{other_user_id}")
589-
async def dm_history(other_user_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
591+
@rate_limit_per_ip("60/minute") # Per-IP limit to prevent abuse
592+
async def dm_history(request: Request, other_user_id: int, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
590593
return convert_envelopes(
591594
db.query(DMEnvelope)
592595
.filter(
@@ -599,7 +602,8 @@ async def dm_history(other_user_id: int, current_user: User = Depends(get_curren
599602

600603

601604
@router.get("/dm/conversations")
602-
async def get_dm_conversations(current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
605+
@rate_limit_per_ip("60/minute") # Per-IP limit to prevent abuse
606+
async def get_dm_conversations(request: Request, current_user: User = Depends(get_current_user), db: Session = Depends(get_db)):
603607
# Get all DM conversations where current user is involved
604608
conversations_query = db.query(DMEnvelope).filter(
605609
(DMEnvelope.sender_id == current_user.id) | (DMEnvelope.recipient_id == current_user.id)
@@ -641,7 +645,7 @@ async def get_dm_conversations(current_user: User = Depends(get_current_user), d
641645

642646

643647
@router.put("/edit_message/{message_id}")
644-
@rate_limit_per_user("20/minute")
648+
@rate_limit_per_ip("20/minute")
645649
async def edit_message(
646650
request: Request,
647651
message_id: int,
@@ -718,7 +722,7 @@ async def delete_message(
718722

719723

720724
@router.post("/add_reaction")
721-
@rate_limit_per_user("50/minute")
725+
@rate_limit_per_ip("50/minute")
722726
async def add_reaction(
723727
request: Request,
724728
reaction_request: ReactionRequest,
@@ -787,7 +791,7 @@ async def add_reaction(
787791

788792

789793
@router.post("/dm/add_reaction")
790-
@rate_limit_per_user("50/minute")
794+
@rate_limit_per_ip("50/minute")
791795
async def add_dm_reaction(
792796
request: Request,
793797
reaction_request: DMReactionRequest,

backend/routes/profile.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .messaging import messagingManager
1818
from security.audit import log_security
1919
from security.profanity import contains_profanity
20-
from security.rate_limit import rate_limit_per_user
20+
from security.rate_limit import rate_limit_per_ip
2121

2222
router = APIRouter()
2323

@@ -41,7 +41,7 @@ class UpdateProfileRequest(BaseModel):
4141
os.makedirs(PROFILE_PICTURES_DIR, exist_ok=True)
4242

4343
@router.post("/upload-profile-picture")
44-
@rate_limit_per_user("10/minute")
44+
@rate_limit_per_ip("10/minute")
4545
async def upload_profile_picture(
4646
request: Request,
4747
profile_picture: UploadFile = File(...),
@@ -167,7 +167,7 @@ async def list_users(
167167
}
168168

169169
@router.put("/user/profile")
170-
@rate_limit_per_user("10/minute")
170+
@rate_limit_per_ip("10/minute")
171171
async def update_user_profile(
172172
request: Request,
173173
update_request: UpdateProfileRequest,
@@ -245,7 +245,7 @@ async def update_user_profile(
245245

246246

247247
@router.put("/user/bio")
248-
@rate_limit_per_user("10/minute")
248+
@rate_limit_per_ip("10/minute")
249249
async def update_user_bio(
250250
request: Request,
251251
bio_request: UpdateBioRequest,

backend/security/rate_limit.py

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,24 @@
44
from fastapi import Request
55
from slowapi import Limiter
66
from slowapi.util import get_remote_address
7-
from slowapi.errors import RateLimitExceeded
87

98
from utils import get_client_ip
109

10+
def get_ip_key(request: Request) -> str:
11+
"""Get rate limit key based on IP address."""
12+
return get_client_ip(request) or get_remote_address(request)
13+
1114
# Initialize limiter with IP-based key function
15+
# Note: We don't set default_limits to avoid affecting all users if one IP is attacked.
16+
# Each endpoint should have an explicit rate limit based on its sensitivity.
1217
limiter = Limiter(
13-
key_func=lambda request: get_client_ip(request) or get_remote_address(request),
14-
default_limits=["1000/hour"], # Global default limit
18+
key_func=get_ip_key,
19+
default_limits=[], # No global default - each endpoint must have explicit limits
1520
storage_uri="memory://", # In-memory storage (can be changed to Redis later)
1621
)
1722

1823

19-
def get_user_id_key(request: Request) -> str:
20-
"""Get rate limit key based on authenticated user ID."""
21-
user = getattr(getattr(request, "state", None), "current_user", None)
22-
if user and hasattr(user, "id"):
23-
return f"user:{user.id}"
24-
# Fallback to IP if not authenticated
25-
return get_client_ip(request) or get_remote_address(request)
26-
27-
28-
def get_ip_key(request: Request) -> str:
29-
"""Get rate limit key based on IP address."""
30-
return get_client_ip(request) or get_remote_address(request)
31-
32-
33-
# Rate limit decorators for different endpoint types
24+
# Rate limit decorator for IP-based limiting
3425
def rate_limit_per_ip(limit: str) -> Callable:
3526
"""Rate limit based on IP address."""
36-
return limiter.limit(limit, key_func=get_ip_key)
37-
38-
39-
def rate_limit_per_user(limit: str) -> Callable:
40-
"""Rate limit based on authenticated user ID, fallback to IP.
41-
42-
Note: The user must be authenticated (get_current_user dependency must run first).
43-
The user will be available in request.state.current_user after authentication.
44-
"""
45-
return limiter.limit(limit, key_func=get_user_id_key)
46-
27+
return limiter.limit(limit, key_func=get_ip_key)

0 commit comments

Comments
 (0)