Skip to content

Commit a1792e1

Browse files
njbrakeclaude
andauthored
fix(gateway): make CORS configurable, remove unsafe wildcard+credentials (#837)
## Description The gateway hardcodes CORS with `allow_origins=["*"]` and `allow_credentials=True`. Per the CORS spec, wildcard + credentials is invalid (browsers reject it). A wildcard origin is also overly permissive for production. This PR adds a `cors_allow_origins` config field (default: empty/disabled). When origins are listed, credentials are allowed. When a wildcard is used, credentials are auto-disabled per spec. CORS is off by default (secure-by-default). ## PR Type - 🐛 Bug Fix ## Checklist - [x] I understand the code I am submitting. - [x] I have added unit tests that prove my fix/feature works - [x] I have run this code locally and verified it fixes the issue. - [x] New and existing tests pass locally - [ ] Documentation was updated where necessary - [x] I have read and followed the [contribution guidelines](https://github.com/mozilla-ai/any-llm/blob/main/CONTRIBUTING.md) - **AI Usage:** - [ ] No AI was used. - [ ] AI was used for drafting/refactoring. - [x] This is fully AI-generated. ## AI Usage Information - AI Model used: Claude Opus 4.6 - AI Developer Tool used: Claude Code - Any other info you'd like to share: Identified during a comprehensive gateway code review. - [x] I am an AI Agent filling out this form (check box if true) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent a35933d commit a1792e1

File tree

3 files changed

+99
-7
lines changed

3 files changed

+99
-7
lines changed

src/any_llm/gateway/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class GatewayConfig(BaseSettings):
3737
host: str = Field(default="0.0.0.0", description="Host to bind the server to") # noqa: S104
3838
port: int = Field(default=8000, description="Port to bind the server to")
3939
master_key: str | None = Field(default=None, description="Master key for protecting management endpoints")
40+
cors_allow_origins: list[str] = Field(
41+
default_factory=list, description="Allowed CORS origins (empty list disables CORS)"
42+
)
4043
providers: dict[str, dict[str, Any]] = Field(
4144
default_factory=dict, description="Pre-configured provider credentials"
4245
)

src/any_llm/gateway/server.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,13 +34,15 @@ def create_app(config: GatewayConfig) -> FastAPI:
3434
version=__version__,
3535
)
3636

37-
app.add_middleware(
38-
CORSMiddleware,
39-
allow_origins=["*"],
40-
allow_credentials=True,
41-
allow_methods=["*"],
42-
allow_headers=["*"],
43-
)
37+
if config.cors_allow_origins:
38+
allow_credentials = "*" not in config.cors_allow_origins
39+
app.add_middleware(
40+
CORSMiddleware,
41+
allow_origins=config.cors_allow_origins,
42+
allow_credentials=allow_credentials,
43+
allow_methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
44+
allow_headers=["Content-Type", "Authorization", "X-AnyLLM-Key"],
45+
)
4446

4547
app.include_router(chat.router)
4648
app.include_router(keys.router)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Tests for CORS configuration."""
2+
3+
from typing import Any
4+
5+
from fastapi.testclient import TestClient
6+
from sqlalchemy.orm import Session
7+
8+
from any_llm.gateway.config import GatewayConfig
9+
from any_llm.gateway.db import get_db
10+
from any_llm.gateway.server import create_app
11+
12+
13+
def test_cors_disabled_by_default(postgres_url: str, test_db: Session) -> None:
14+
"""Test that CORS middleware is not added when cors_allow_origins is empty."""
15+
config = GatewayConfig(
16+
database_url=postgres_url,
17+
master_key="test-master-key",
18+
host="127.0.0.1",
19+
port=8000,
20+
)
21+
22+
app = create_app(config)
23+
24+
def override_get_db() -> Any:
25+
yield test_db
26+
27+
app.dependency_overrides[get_db] = override_get_db
28+
29+
with TestClient(app) as client:
30+
response = client.get("/health", headers={"Origin": "https://evil.com"})
31+
assert response.status_code == 200
32+
assert "access-control-allow-origin" not in response.headers
33+
34+
35+
def test_cors_with_specific_origins(postgres_url: str, test_db: Session) -> None:
36+
"""Test that CORS allows only configured origins."""
37+
config = GatewayConfig(
38+
database_url=postgres_url,
39+
master_key="test-master-key",
40+
host="127.0.0.1",
41+
port=8000,
42+
cors_allow_origins=["https://trusted.com"],
43+
)
44+
45+
app = create_app(config)
46+
47+
def override_get_db() -> Any:
48+
yield test_db
49+
50+
app.dependency_overrides[get_db] = override_get_db
51+
52+
with TestClient(app) as client:
53+
# Trusted origin should get CORS headers
54+
response = client.get("/health", headers={"Origin": "https://trusted.com"})
55+
assert response.status_code == 200
56+
assert response.headers.get("access-control-allow-origin") == "https://trusted.com"
57+
assert response.headers.get("access-control-allow-credentials") == "true"
58+
59+
# Untrusted origin should not get CORS headers
60+
response = client.get("/health", headers={"Origin": "https://evil.com"})
61+
assert response.status_code == 200
62+
assert response.headers.get("access-control-allow-origin") != "https://evil.com"
63+
64+
65+
def test_cors_wildcard_disables_credentials(postgres_url: str, test_db: Session) -> None:
66+
"""Test that wildcard origin disables allow_credentials per CORS spec."""
67+
config = GatewayConfig(
68+
database_url=postgres_url,
69+
master_key="test-master-key",
70+
host="127.0.0.1",
71+
port=8000,
72+
cors_allow_origins=["*"],
73+
)
74+
75+
app = create_app(config)
76+
77+
def override_get_db() -> Any:
78+
yield test_db
79+
80+
app.dependency_overrides[get_db] = override_get_db
81+
82+
with TestClient(app) as client:
83+
response = client.get("/health", headers={"Origin": "https://any-site.com"})
84+
assert response.status_code == 200
85+
assert response.headers.get("access-control-allow-origin") == "*"
86+
# Credentials should NOT be allowed with wildcard
87+
assert response.headers.get("access-control-allow-credentials") != "true"

0 commit comments

Comments
 (0)