Skip to content

Commit 018fb9f

Browse files
committed
Added Rate Limit to protect API Key balance
1 parent bcd0cdb commit 018fb9f

File tree

2 files changed

+336
-5
lines changed

2 files changed

+336
-5
lines changed

app.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
from database import init_connection_pool, init_db, execute_query, execute_transaction
1313
from ai_agent_deepseek import ai_agent
1414
import time
15+
from functools import wraps
16+
from collections import defaultdict
1517

1618
# Load environment variables
1719
load_dotenv()
@@ -40,6 +42,130 @@
4042
# Hardcoded secret key (CWE-798)
4143
app.secret_key = "secret123"
4244

45+
# Rate limiting configuration
46+
RATE_LIMIT_WINDOW = 3 * 60 * 60 # 3 hours in seconds
47+
UNAUTHENTICATED_LIMIT = 5 # requests per IP per window
48+
AUTHENTICATED_LIMIT = 10 # requests per user per window
49+
50+
# In-memory rate limiting storage
51+
# Format: {key: [(timestamp, request_count), ...]}
52+
rate_limit_storage = defaultdict(list)
53+
54+
def cleanup_rate_limit_storage():
55+
"""Clean up old entries from rate limit storage"""
56+
current_time = time.time()
57+
cutoff_time = current_time - RATE_LIMIT_WINDOW
58+
59+
for key in list(rate_limit_storage.keys()):
60+
# Remove entries older than the rate limit window
61+
rate_limit_storage[key] = [
62+
(timestamp, count) for timestamp, count in rate_limit_storage[key]
63+
if timestamp > cutoff_time
64+
]
65+
# Remove empty entries
66+
if not rate_limit_storage[key]:
67+
del rate_limit_storage[key]
68+
69+
def get_client_ip():
70+
"""Get client IP address, considering proxy headers"""
71+
if request.headers.get('X-Forwarded-For'):
72+
return request.headers.get('X-Forwarded-For').split(',')[0].strip()
73+
elif request.headers.get('X-Real-IP'):
74+
return request.headers.get('X-Real-IP')
75+
else:
76+
return request.remote_addr
77+
78+
def check_rate_limit(key, limit):
79+
"""Check if the request should be rate limited"""
80+
cleanup_rate_limit_storage()
81+
current_time = time.time()
82+
83+
# Count requests in the current window
84+
request_count = sum(count for timestamp, count in rate_limit_storage[key] if timestamp > current_time - RATE_LIMIT_WINDOW)
85+
86+
if request_count >= limit:
87+
return False, request_count, limit
88+
89+
# Add current request
90+
rate_limit_storage[key].append((current_time, 1))
91+
return True, request_count + 1, limit
92+
93+
def ai_rate_limit(f):
94+
"""Rate limiting decorator for AI endpoints"""
95+
@wraps(f)
96+
def decorated_function(*args, **kwargs):
97+
client_ip = get_client_ip()
98+
99+
# Check if this is an authenticated request
100+
auth_header = request.headers.get('Authorization')
101+
if auth_header and auth_header.startswith('Bearer '):
102+
# Extract token and get user info
103+
token = auth_header.split(' ')[1]
104+
try:
105+
user_data = verify_token(token)
106+
if user_data:
107+
# Authenticated mode: rate limit by both user and IP
108+
user_key = f"ai_auth_user_{user_data['user_id']}"
109+
ip_key = f"ai_auth_ip_{client_ip}"
110+
111+
# Check user-based rate limit
112+
user_allowed, user_count, user_limit = check_rate_limit(user_key, AUTHENTICATED_LIMIT)
113+
if not user_allowed:
114+
return jsonify({
115+
'status': 'error',
116+
'message': f'Rate limit exceeded for user. You have made {user_count} requests in the last 3 hours. Limit is {user_limit} requests per 3 hours.',
117+
'rate_limit_info': {
118+
'limit_type': 'authenticated_user',
119+
'current_count': user_count,
120+
'limit': user_limit,
121+
'window_hours': 3,
122+
'user_id': user_data['user_id']
123+
}
124+
}), 429
125+
126+
# Check IP-based rate limit
127+
ip_allowed, ip_count, ip_limit = check_rate_limit(ip_key, AUTHENTICATED_LIMIT)
128+
if not ip_allowed:
129+
return jsonify({
130+
'status': 'error',
131+
'message': f'Rate limit exceeded for IP address. This IP has made {ip_count} requests in the last 3 hours. Limit is {ip_limit} requests per 3 hours.',
132+
'rate_limit_info': {
133+
'limit_type': 'authenticated_ip',
134+
'current_count': ip_count,
135+
'limit': ip_limit,
136+
'window_hours': 3,
137+
'client_ip': client_ip
138+
}
139+
}), 429
140+
141+
# Both checks passed, proceed with authenticated function
142+
return f(*args, **kwargs)
143+
except:
144+
pass # Fall through to unauthenticated handling
145+
146+
# Unauthenticated mode: rate limit by IP only
147+
ip_key = f"ai_unauth_ip_{client_ip}"
148+
ip_allowed, ip_count, ip_limit = check_rate_limit(ip_key, UNAUTHENTICATED_LIMIT)
149+
150+
if not ip_allowed:
151+
return jsonify({
152+
'status': 'error',
153+
'message': f'Rate limit exceeded. This IP address has made {ip_count} requests in the last 3 hours. Limit is {ip_limit} requests per 3 hours for unauthenticated users.',
154+
'rate_limit_info': {
155+
'limit_type': 'unauthenticated_ip',
156+
'current_count': ip_count,
157+
'limit': ip_limit,
158+
'window_hours': 3,
159+
'client_ip': client_ip,
160+
'suggestion': 'Log in to get higher rate limits (10 requests per 3 hours)'
161+
}
162+
}), 429
163+
164+
# Rate limit check passed, proceed with unauthenticated function
165+
return f(*args, **kwargs)
166+
167+
return decorated_function
168+
43169
UPLOAD_FOLDER = 'static/uploads'
44170
if not os.path.exists(UPLOAD_FOLDER):
45171
os.makedirs(UPLOAD_FOLDER)
@@ -1395,6 +1521,7 @@ def get_payment_history(current_user):
13951521

13961522
# AI CUSTOMER SUPPORT AGENT ROUTES (INTENTIONALLY VULNERABLE)
13971523
@app.route('/api/ai/chat', methods=['POST'])
1524+
@ai_rate_limit
13981525
@token_required
13991526
def ai_chat_authenticated(current_user):
14001527
"""
@@ -1466,6 +1593,7 @@ def ai_chat_authenticated(current_user):
14661593
}), 500
14671594

14681595
@app.route('/api/ai/chat/anonymous', methods=['POST'])
1596+
@ai_rate_limit
14691597
def ai_chat_anonymous():
14701598
"""
14711599
Anonymous AI chat endpoint (UNAUTHENTICATED MODE)
@@ -1504,6 +1632,7 @@ def ai_chat_anonymous():
15041632
}), 500
15051633

15061634
@app.route('/api/ai/system-info', methods=['GET'])
1635+
@ai_rate_limit
15071636
def ai_system_info():
15081637
"""
15091638
VULNERABILITY: Exposes AI system information without authentication
@@ -1541,6 +1670,77 @@ def ai_system_info():
15411670
'message': str(e)
15421671
}), 500
15431672

1673+
@app.route('/api/ai/rate-limit-status', methods=['GET'])
1674+
def ai_rate_limit_status():
1675+
"""
1676+
Check current rate limit status for AI endpoints
1677+
Useful for debugging and transparency
1678+
"""
1679+
try:
1680+
cleanup_rate_limit_storage()
1681+
client_ip = get_client_ip()
1682+
current_time = time.time()
1683+
1684+
status = {
1685+
'status': 'success',
1686+
'client_ip': client_ip,
1687+
'rate_limits': {
1688+
'unauthenticated': {
1689+
'limit': UNAUTHENTICATED_LIMIT,
1690+
'window_hours': 3,
1691+
'requests_made': 0
1692+
},
1693+
'authenticated': {
1694+
'limit': AUTHENTICATED_LIMIT,
1695+
'window_hours': 3,
1696+
'user_requests_made': 0,
1697+
'ip_requests_made': 0
1698+
}
1699+
}
1700+
}
1701+
1702+
# Check unauthenticated rate limit
1703+
unauth_key = f"ai_unauth_ip_{client_ip}"
1704+
unauth_count = sum(count for timestamp, count in rate_limit_storage[unauth_key]
1705+
if timestamp > current_time - RATE_LIMIT_WINDOW)
1706+
status['rate_limits']['unauthenticated']['requests_made'] = unauth_count
1707+
status['rate_limits']['unauthenticated']['remaining'] = max(0, UNAUTHENTICATED_LIMIT - unauth_count)
1708+
1709+
# Check if user is authenticated
1710+
auth_header = request.headers.get('Authorization')
1711+
if auth_header and auth_header.startswith('Bearer '):
1712+
token = auth_header.split(' ')[1]
1713+
try:
1714+
user_data = verify_token(token)
1715+
if user_data:
1716+
# Check authenticated rate limits
1717+
user_key = f"ai_auth_user_{user_data['user_id']}"
1718+
ip_key = f"ai_auth_ip_{client_ip}"
1719+
1720+
user_count = sum(count for timestamp, count in rate_limit_storage[user_key]
1721+
if timestamp > current_time - RATE_LIMIT_WINDOW)
1722+
ip_count = sum(count for timestamp, count in rate_limit_storage[ip_key]
1723+
if timestamp > current_time - RATE_LIMIT_WINDOW)
1724+
1725+
status['rate_limits']['authenticated']['user_requests_made'] = user_count
1726+
status['rate_limits']['authenticated']['ip_requests_made'] = ip_count
1727+
status['rate_limits']['authenticated']['user_remaining'] = max(0, AUTHENTICATED_LIMIT - user_count)
1728+
status['rate_limits']['authenticated']['ip_remaining'] = max(0, AUTHENTICATED_LIMIT - ip_count)
1729+
status['authenticated_user'] = {
1730+
'user_id': user_data['user_id'],
1731+
'username': user_data['username']
1732+
}
1733+
except:
1734+
pass # Token invalid, stay with unauthenticated status
1735+
1736+
return jsonify(status)
1737+
1738+
except Exception as e:
1739+
return jsonify({
1740+
'status': 'error',
1741+
'message': str(e)
1742+
}), 500
1743+
15441744
if __name__ == '__main__':
15451745
init_db()
15461746
init_auth_routes(app)

0 commit comments

Comments
 (0)