Skip to content

Commit 32f91bd

Browse files
authored
feat(server): add global rate limits (#1686)
Signed-off-by: Radek Ježek <radek.jezek@ibm.com>
1 parent 4e8da1c commit 32f91bd

File tree

31 files changed

+950
-190
lines changed

31 files changed

+950
-190
lines changed

.github/workflows/integration-test.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ jobs:
2323
integration-test:
2424
runs-on: ubuntu-latest
2525
steps:
26+
- name: Maximize build space
27+
uses: easimon/maximize-build-space@master
28+
with:
29+
root-reserve-mb: 15360
30+
temp-reserve-mb: 2048
31+
swap-size-mb: 1024
32+
remove-dotnet: 'true'
2633
- uses: actions/checkout@v4
2734
- name: "Set up Lima"
2835
uses: lima-vm/lima-actions/setup@v1

apps/agentstack-server/.vscode/launch.json

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
{
55
"name": "agentstack-server",
66
"type": "debugpy",
7+
"justMyCode": false,
78
"request": "launch",
89
"module": "uvicorn",
910
"args": [
@@ -13,6 +14,20 @@
1314
"--timeout-keep-alive=60",
1415
"--timeout-graceful-shutdown=2"
1516
],
17+
},
18+
{
19+
"name": "Python: Debug Tests",
20+
"type": "debugpy",
21+
"request": "launch",
22+
"program": "${file}",
23+
"purpose": [
24+
"debug-test"
25+
],
26+
"console": "integratedTerminal",
27+
"justMyCode": false,
28+
"presentation": {
29+
"hidden": true, // keep original launch order in 'run and debug' tab
30+
}
1631
}
1732
]
1833
}

apps/agentstack-server/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"mcp>=1.13.1",
4848
"opentelemetry-instrumentation-httpx>=0.59b0",
4949
"opentelemetry-instrumentation-fastapi>=0.59b0",
50+
"limits[async-redis]>=5.3.0",
5051
]
5152

5253
[dependency-groups]
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import ipaddress
5+
6+
from starlette.types import ASGIApp, Receive, Scope, Send
7+
8+
9+
class ProxyHeadersMiddleware:
    """
    ASGI middleware that rewrites the connection scope from proxy headers.

    Modified https://github.com/Kludex/uvicorn/blob/main/uvicorn/middleware/proxy_headers.py
    Added "host" support and parsing of the RFC 7239 "Forwarded" header.

    Headers are honoured only when the directly connected peer (``scope["client"]``)
    is contained in ``trusted_hosts``. When it is, the middleware may update:

    - ``scope["scheme"]`` from ``X-Forwarded-Proto`` / ``Forwarded: proto=``
    - the ``host`` header and ``scope["server"]`` from ``X-Forwarded-Host`` / ``Forwarded: host=``
    - ``scope["client"]`` from ``X-Forwarded-For`` / ``Forwarded: for=``

    Values found in ``Forwarded`` take precedence over the ``X-Forwarded-*`` ones,
    but only for the parameters that the ``Forwarded`` element actually carries.
    """

    def __init__(self, app: ASGIApp, trusted_hosts: list[str] | str = "127.0.0.1") -> None:
        """
        Args:
            app: The wrapped ASGI application.
            trusted_hosts: Peers whose proxy headers are trusted; passed to ``_TrustedHosts``.
        """
        self.app = app
        self.trusted_hosts = _TrustedHosts(trusted_hosts)

    @staticmethod
    def _parse_forwarded_element(element: str) -> dict[str, str]:
        """Parse one ``Forwarded`` element (``a=b;c=d``) into a dict with lower-cased names.

        Segments without ``=`` are skipped instead of raising, and surrounding double
        quotes are removed from values.
        """
        directives: dict[str, str] = {}
        for pair in element.split(";"):
            name, separator, value = pair.partition("=")
            if not separator:
                # Malformed segment: ignore it so a bad header cannot crash the request.
                continue
            directives[name.strip().lower()] = value.strip().strip('"')
        return directives

    @staticmethod
    def _node_address(node: str) -> str | None:
        """Return the address part of a ``for=`` node, dropping brackets and any port."""
        node = node.strip()
        if node.startswith("["):
            # "[IPv6]" or "[IPv6]:port"
            return node[1:].partition("]")[0] or None
        if node.count(":") == 1:
            # "IPv4:port" (a bare IPv6 address always contains at least two colons).
            node = node.partition(":")[0]
        return node.strip("[]") or None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Apply trusted proxy headers to ``scope`` and delegate to the wrapped app."""
        if scope["type"] == "lifespan":
            return await self.app(scope, receive, send)

        client_addr = scope.get("client")
        client_host = client_addr[0] if client_addr else None

        if client_host in self.trusted_hosts:
            headers = dict(scope["headers"])

            proto = None
            if b"x-forwarded-proto" in headers:
                proto = headers[b"x-forwarded-proto"].decode("latin1").strip()

            host = None
            if b"x-forwarded-host" in headers:
                host = headers[b"x-forwarded-host"].decode("latin1").strip()

            # X-Forwarded-For: client, proxy1, proxy2
            client_ip = None
            if b"x-forwarded-for" in headers:
                client_ip = headers[b"x-forwarded-for"].decode("latin1").split(",")[0].strip()

            if b"forwarded" in headers:
                for element in headers[b"forwarded"].decode("latin1").split(","):
                    directives = self._parse_forwarded_element(element)
                    if "proto" in directives or "host" in directives or "for" in directives:
                        # Only override what this element provides; keep the
                        # X-Forwarded-* values for parameters it omits.
                        proto = directives.get("proto", proto)
                        host = directives.get("host", host)
                        if "for" in directives:
                            client_ip = self._node_address(directives["for"])
                        # Only the first element that carries a relevant parameter is used.
                        break

            if proto in {"http", "https", "ws", "wss"}:
                if scope["type"] == "websocket":
                    scope["scheme"] = proto.replace("http", "ws")
                else:
                    scope["scheme"] = proto

            if host:
                scope["headers"] = [
                    (key, value) if key != b"host" else (b"host", host.encode()) for key, value in scope["headers"]
                ]
                scope["server"] = (host, None)

            if client_ip:
                scope["client"] = (client_ip, 0)

        return await self.app(scope, receive, send)
69+
70+
71+
def _parse_raw_hosts(value: str) -> list[str]:
72+
return [item.strip() for item in value.split(",")]
73+
74+
75+
class _TrustedHosts:
76+
"""Container for trusted hosts and networks"""
77+
78+
def __init__(self, trusted_hosts: list[str] | str) -> None:
79+
self.always_trust: bool = trusted_hosts in ("*", ["*"])
80+
81+
self.trusted_literals: set[str] = set()
82+
self.trusted_hosts: set[ipaddress.IPv4Address | ipaddress.IPv6Address] = set()
83+
self.trusted_networks: set[ipaddress.IPv4Network | ipaddress.IPv6Network] = set()
84+
85+
# Notes:
86+
# - We separate hosts from literals as there are many ways to write
87+
# an IPv6 Address so we need to compare by object.
88+
# - We don't convert IP Address to single host networks (e.g. /32 / 128) as
89+
# it more efficient to do an address lookup in a set than check for
90+
# membership in each network.
91+
# - We still allow literals as it might be possible that we receive a
92+
# something that isn't an IP Address e.g. a unix socket.
93+
94+
if not self.always_trust:
95+
if isinstance(trusted_hosts, str):
96+
trusted_hosts = _parse_raw_hosts(trusted_hosts)
97+
98+
for host in trusted_hosts:
99+
# Note: because we always convert invalid IP types to literals it
100+
# is not possible for the user to know they provided a malformed IP
101+
# type - this may lead to unexpected / difficult to debug behaviour.
102+
103+
if "/" in host:
104+
# Looks like a network
105+
try:
106+
self.trusted_networks.add(ipaddress.ip_network(host))
107+
except ValueError:
108+
# Was not a valid IP Network
109+
self.trusted_literals.add(host)
110+
else:
111+
try:
112+
self.trusted_hosts.add(ipaddress.ip_address(host))
113+
except ValueError:
114+
# Was not a valid IP Address
115+
self.trusted_literals.add(host)
116+
117+
def __contains__(self, host: str | None) -> bool:
118+
if self.always_trust:
119+
return True
120+
121+
if not host:
122+
return False
123+
124+
try:
125+
ip = ipaddress.ip_address(host)
126+
if ip in self.trusted_hosts:
127+
return True
128+
return any(ip in net for net in self.trusted_networks)
129+
130+
except ValueError:
131+
return host in self.trusted_literals
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2025 © BeeAI a Series of LF Projects, LLC
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
import hashlib
5+
import logging
6+
import time
7+
from typing import Final, override
8+
9+
from fastapi import Request, Response, status
10+
from fastapi.responses import JSONResponse
11+
from limits import RateLimitItem
12+
from limits.aio.storage import Storage
13+
from limits.aio.strategies import STRATEGIES, RateLimiter
14+
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
15+
from starlette.types import ASGIApp
16+
17+
from agentstack_server.configuration import RateLimitConfiguration
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
class RateLimitMiddleware(BaseHTTPMiddleware):
    """
    Rate limiting middleware that uses the limits library.

    The storage backend is injected as a ``limits.aio.storage.Storage`` instance and
    the strategy is looked up by name from the configuration.

    Rate limit keys are generated based on authentication type:
    - Bearer tokens (OAuth/JWT): hashes the token
    - Basic auth: hashes the credentials
    - No auth: uses client IP address

    NOTE(review): the key is derived from the raw ``Authorization`` header before the
    request reaches any authentication code, so every distinct (even invalid) token
    gets its own bucket - confirm this is acceptable for the deployment.
    """

    def __init__(
        self,
        app: ASGIApp,
        limiter_storage: Storage,
        configuration: RateLimitConfiguration,
    ):
        """
        Args:
            app: The wrapped ASGI application.
            limiter_storage: Storage used by the rate limiting strategy.
            configuration: Provides ``enabled``, ``limits_parsed`` and ``strategy``.
        """
        super().__init__(app)
        self.enabled: Final[bool] = configuration.enabled
        self.limits: Final[list[RateLimitItem]] = sorted(configuration.limits_parsed)
        self.limiter: Final[RateLimiter] = STRATEGIES[configuration.strategy](limiter_storage)

        logger.info(
            "Rate limiting initialized:\n"
            + f"  Storage class: {type(limiter_storage).__name__}\n"
            + f"  Strategy class: {type(self.limiter).__name__}\n"
            + f"  Limits: {[str(limit) for limit in self.limits]}"
        )

    def _hash_secret(self, secret: str) -> str:
        """Return the SHA-256 hex digest of ``secret`` so raw credentials never become keys."""
        return hashlib.sha256(secret.encode()).hexdigest()

    def _extract_auth_key(self, request: Request) -> str:
        """
        Extract authentication key from request for rate limiting.

        Priority:
        1. Bearer token (OAuth/JWT or internal JWT), hashed
        2. Basic auth credentials, hashed
        3. Client IP address

        The auth scheme is compared case-insensitively, as HTTP auth schemes are
        case-insensitive.
        """
        auth_header = request.headers.get("authorization", "")
        scheme, separator, credentials = auth_header.partition(" ")

        # A scheme with no following space carries no credentials: fall through to the IP.
        if separator:
            scheme = scheme.lower()
            if scheme == "bearer":
                return f"bearer:{self._hash_secret(credentials)}"
            if scheme == "basic":
                return f"basic:{self._hash_secret(credentials)}"

        # Fallback to client IP
        client_host = request.client.host if request.client else "unknown"
        return f"ip:{client_host}"

    @override
    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        """Process request with rate limiting.

        Returns a 429 JSON response as soon as one configured limit is exhausted,
        otherwise the downstream response. In both cases the ``X-RateLimit-*`` and
        ``Retry-After`` headers are set from the window stats of the reported limit.
        """
        if not self.enabled or not self.limits or request.url.path == "/healthcheck":
            return await call_next(request)

        rate_limit_key = self._extract_auth_key(request)

        response: Response

        # Report the first limit in sorted order (which should be the shortest
        # time period) unless a specific limit is exceeded below.
        header_limit = self.limits[0]

        for limit in self.limits:
            if not await self.limiter.hit(limit, rate_limit_key):
                logger.warning(
                    "Rate limit exceeded for key '%s...' on %s %s (limit: %s)",
                    rate_limit_key[:20],
                    request.method,
                    request.url.path,
                    limit,
                )

                header_limit = limit
                response = JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"error": "rate_limit_exceeded", "detail": f"Rate limit exceeded: {limit}"},
                )
                break
        else:
            # No limit was exceeded: let the request through.
            response = await call_next(request)

        reset_time, remaining = await self.limiter.get_window_stats(header_limit, rate_limit_key)

        # If downstream already asked the client to wait longer, keep the later deadline.
        if existing_retry_after_header := response.headers.get("Retry-After"):
            try:
                retry_after = int(existing_retry_after_header)
                retry_after_timestamp = time.time() + retry_after
                reset_time = max(reset_time, retry_after_timestamp)
            except ValueError:
                logger.warning("Invalid Retry-After header value: %s", existing_retry_after_header)

        response.headers["X-RateLimit-Limit"] = str(header_limit.amount)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(reset_time)
        # Retry-After is a non-negative number of seconds; clamp in case the
        # window has already expired by the time the headers are written.
        response.headers["Retry-After"] = str(max(0, int(reset_time - time.time())))

        return response

apps/agentstack-server/src/agentstack_server/application.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import logging
55
import time
66
from collections.abc import Iterable
7-
from contextlib import asynccontextmanager, suppress
7+
from contextlib import asynccontextmanager, nullcontext, suppress
88
from importlib.metadata import PackageNotFoundError, version
99

1010
import procrastinate
@@ -14,10 +14,13 @@
1414
from fastapi.openapi.utils import get_openapi
1515
from fastapi.responses import JSONResponse, ORJSONResponse
1616
from kink import Container, di, inject
17+
from limits.aio.storage import Storage
1718
from opentelemetry.metrics import CallbackOptions, Observation, get_meter
1819
from procrastinate.exceptions import AlreadyEnqueued
1920
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_500_INTERNAL_SERVER_ERROR
2021

22+
from agentstack_server.api.middleware.proxy_headers import ProxyHeadersMiddleware
23+
from agentstack_server.api.middleware.rate_limit import RateLimitMiddleware
2124
from agentstack_server.api.routes.a2a import router as a2a_router
2225
from agentstack_server.api.routes.auth import well_known_router as auth_well_known_router
2326
from agentstack_server.api.routes.configurations import router as configuration_router
@@ -44,7 +47,6 @@
4447
from agentstack_server.run_workers import run_workers
4548
from agentstack_server.service_layer.services.mcp import McpService
4649
from agentstack_server.telemetry import INSTRUMENTATION_NAME, shutdown_telemetry
47-
from agentstack_server.utils.fastapi import ProxyHeadersMiddleware
4850

4951
logger = logging.getLogger(__name__)
5052

@@ -177,7 +179,7 @@ def scrape_platform_status(options: CallbackOptions) -> Iterable[Observation]:
177179
# meter.create_observable_gauge("providers_by_status", callbacks=[scrape_providers_by_status])
178180

179181

180-
def app(*, dependency_overrides: Container | None = None) -> FastAPI:
182+
def app(*, dependency_overrides: Container | None = None, enable_workers: bool = True) -> FastAPI:
181183
"""Entrypoint for API application, called by Uvicorn"""
182184

183185
logger.info("Bootstrapping dependencies...")
@@ -189,7 +191,11 @@ def app(*, dependency_overrides: Container | None = None) -> FastAPI:
189191
async def lifespan(_app: FastAPI, procrastinate_app: procrastinate.App, mcp_service: McpService):
190192
try:
191193
register_telemetry()
192-
async with procrastinate_app.open_async(), run_workers(app=procrastinate_app), mcp_service:
194+
async with (
195+
procrastinate_app.open_async(),
196+
run_workers(app=procrastinate_app) if enable_workers else nullcontext(),
197+
mcp_service,
198+
):
193199
with suppress(AlreadyEnqueued):
194200
# Force initial sync of the registry immediately
195201
await check_registry.defer_async(timestamp=int(time.time()))
@@ -212,7 +218,9 @@ async def lifespan(_app: FastAPI, procrastinate_app: procrastinate.App, mcp_serv
212218
logger.info("Mounting routes...")
213219
mount_routes(app)
214220

221+
# Execution order is important here: https://fastapi.tiangolo.com/tutorial/middleware/#multiple-middleware-execution-order
222+
app.add_middleware(RateLimitMiddleware, limiter_storage=di[Storage], configuration=configuration.rate_limit)
215223
app.add_middleware(ProxyHeadersMiddleware, trusted_hosts="*" if configuration.trust_proxy_headers else "")
216-
register_global_exception_handlers(app)
217224

225+
register_global_exception_handlers(app)
218226
return app

0 commit comments

Comments
 (0)