Skip to content

Commit a6a06cc

Browse files
StuMasonclaude
andcommitted
feat: Add rate limit headers to API responses
Adds X-RateLimit-* headers to responses for authenticated requests: - X-RateLimit-Limit: Max requests per hour (1000 default) - X-RateLimit-Remaining: Remaining requests in current window - X-RateLimit-Reset: Unix timestamp when window resets This allows API clients to monitor their rate limit status and implement backoff strategies before hitting 429 errors. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent d76e104 commit a6a06cc

File tree

3 files changed

+79
-1
lines changed

3 files changed

+79
-1
lines changed

src/polar_flow_server/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from polar_flow_server.api import api_routers
2121
from polar_flow_server.core.config import settings
2222
from polar_flow_server.core.database import close_database, engine, init_database
23+
from polar_flow_server.middleware import RateLimitHeadersMiddleware
2324
from polar_flow_server.routes import root_redirect
2425

2526
# Configure structured logging
@@ -120,7 +121,7 @@ def create_app() -> Litestar:
120121
),
121122
),
122123
],
123-
middleware=[session_config.middleware],
124+
middleware=[session_config.middleware, RateLimitHeadersMiddleware],
124125
csrf_config=csrf_config,
125126
stores={"session_store": session_store},
126127
debug=settings.log_level == "DEBUG",
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Middleware modules."""
2+
3+
from polar_flow_server.middleware.rate_limit import RateLimitHeadersMiddleware
4+
5+
__all__ = ["RateLimitHeadersMiddleware"]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
"""Rate limit response headers.
2+
3+
Adds X-RateLimit-* headers to responses for authenticated requests.
4+
"""
5+
6+
from litestar import Response
7+
from litestar.connection import ASGIConnection
8+
from litestar.types import Message, Receive, Scope, Send
9+
10+
from polar_flow_server.core.auth import RATE_LIMIT_STATE_KEY
11+
12+
13+
def add_rate_limit_headers(response: Response, connection: ASGIConnection) -> Response:
14+
"""Add rate limit headers to the response.
15+
16+
This is an after_request hook that reads rate limit info from
17+
connection state (set by per_user_api_key_guard) and adds
18+
X-RateLimit-* headers to the response.
19+
20+
Headers added:
21+
- X-RateLimit-Limit: Max requests per hour
22+
- X-RateLimit-Remaining: Remaining requests in current window
23+
- X-RateLimit-Reset: Unix timestamp when window resets
24+
"""
25+
rate_limit_info = connection.state.get(RATE_LIMIT_STATE_KEY)
26+
27+
if rate_limit_info:
28+
response.headers["X-RateLimit-Limit"] = str(rate_limit_info["limit"])
29+
response.headers["X-RateLimit-Remaining"] = str(rate_limit_info["remaining"])
30+
response.headers["X-RateLimit-Reset"] = str(rate_limit_info["reset"])
31+
32+
return response
33+
34+
35+
class RateLimitHeadersMiddleware:
36+
"""Middleware to add rate limit headers to responses.
37+
38+
Note: This middleware captures the state after the app processes the request,
39+
allowing it to access state set by guards during request processing.
40+
"""
41+
42+
def __init__(self, app: "ASGIApp") -> None: # noqa: F821
43+
"""Initialize middleware with the ASGI app."""
44+
self.app = app
45+
46+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
47+
"""Process the request and add rate limit headers to response."""
48+
if scope["type"] != "http":
49+
await self.app(scope, receive, send)
50+
return
51+
52+
# Initialize state dict if not present
53+
if "state" not in scope:
54+
scope["state"] = {}
55+
56+
async def send_wrapper(message: Message) -> None:
57+
"""Wrap send to inject rate limit headers."""
58+
if message["type"] == "http.response.start":
59+
rate_limit_info = scope.get("state", {}).get(RATE_LIMIT_STATE_KEY)
60+
if rate_limit_info:
61+
headers = list(message.get("headers", []))
62+
headers.extend(
63+
[
64+
(b"x-ratelimit-limit", str(rate_limit_info["limit"]).encode()),
65+
(b"x-ratelimit-remaining", str(rate_limit_info["remaining"]).encode()),
66+
(b"x-ratelimit-reset", str(rate_limit_info["reset"]).encode()),
67+
]
68+
)
69+
message = {**message, "headers": headers}
70+
await send(message)
71+
72+
await self.app(scope, receive, send_wrapper)

0 commit comments

Comments
 (0)