Skip to content

Commit 74780f0

Browse files
authored
fix: retain proxy headers when behind proxy (#88)
When applying proxy headers to upstream requests, we should first attempt to retrieve those headers from the request. These requests may already be set if the STAC Auth Proxy is itself behind a reverse proxy (e.g., NGINX). Related to #86
1 parent daf5d09 commit 74780f0

File tree

2 files changed

+216
-101
lines changed

2 files changed

+216
-101
lines changed

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,18 @@ def _prepare_headers(self, request: Request) -> MutableHeaders:
3737
headers = MutableHeaders(request.headers)
3838
headers.setdefault("Via", f"1.1 {self.proxy_name}")
3939

40-
proxy_client = request.client.host if request.client else "unknown"
41-
proxy_proto = request.url.scheme
42-
proxy_host = request.url.netloc
43-
proxy_path = request.base_url.path
40+
proxy_client = headers.get(
41+
"X-Forwarded-For", request.client.host if request.client else "unknown"
42+
)
43+
proxy_proto = headers.get("X-Forwarded-Proto", request.url.scheme)
44+
proxy_host = headers.get("X-Forwarded-Host", request.url.netloc)
45+
proxy_path = headers.get("X-Forwarded-Path", request.base_url.path)
4446
headers.setdefault(
4547
"Forwarded",
4648
f"for={proxy_client};host={proxy_host};proto={proxy_proto};path={proxy_path}",
4749
)
50+
51+
# NOTE: This is useful if the upstream API does not support the Forwarded header
4852
if self.legacy_forwarded_headers:
4953
headers.setdefault("X-Forwarded-For", proxy_client)
5054
headers.setdefault("X-Forwarded-Host", proxy_host)

tests/test_reverse_proxy.py

Lines changed: 208 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
from stac_auth_proxy.handlers.reverse_proxy import ReverseProxyHandler
77

88

9-
@pytest.fixture
10-
def mock_request():
11-
"""Create a mock FastAPI request."""
12-
scope = {
9+
def create_request(scope_overrides=None, headers=None):
10+
"""Create a mock FastAPI request with custom scope and headers."""
11+
default_scope = {
1312
"type": "http",
1413
"method": "GET",
1514
"path": "/test",
@@ -19,7 +18,20 @@ def mock_request():
1918
(b"accept", b"application/json"),
2019
],
2120
}
22-
return Request(scope)
21+
22+
if scope_overrides:
23+
default_scope.update(scope_overrides)
24+
25+
if headers:
26+
default_scope["headers"] = headers
27+
28+
return Request(default_scope)
29+
30+
31+
@pytest.fixture
32+
def mock_request():
33+
"""Create a mock FastAPI request."""
34+
return create_request()
2335

2436

2537
@pytest.fixture
@@ -28,15 +40,33 @@ def reverse_proxy_handler():
2840
return ReverseProxyHandler(upstream="http://upstream-api.com")
2941

3042

43+
@pytest.mark.parametrize(
44+
"legacy_headers,override_host,proxy_name,expected_host,expected_via",
45+
[
46+
(False, True, "stac-auth-proxy", "upstream-api.com", "1.1 stac-auth-proxy"),
47+
(True, True, "stac-auth-proxy", "upstream-api.com", "1.1 stac-auth-proxy"),
48+
(False, False, "stac-auth-proxy", "localhost:8000", "1.1 stac-auth-proxy"),
49+
(False, True, "custom-proxy", "upstream-api.com", "1.1 custom-proxy"),
50+
],
51+
)
3152
@pytest.mark.asyncio
32-
async def test_basic_headers(mock_request, reverse_proxy_handler):
33-
"""Test that basic headers are properly set."""
34-
headers = reverse_proxy_handler._prepare_headers(mock_request)
53+
async def test_basic_headers(
54+
mock_request, legacy_headers, override_host, proxy_name, expected_host, expected_via
55+
):
56+
"""Test basic header functionality with various configurations."""
57+
handler = ReverseProxyHandler(
58+
upstream="http://upstream-api.com",
59+
legacy_forwarded_headers=legacy_headers,
60+
override_host=override_host,
61+
proxy_name=proxy_name,
62+
)
63+
headers = handler._prepare_headers(mock_request)
3564

3665
# Check standard headers
37-
assert headers["Host"] == "upstream-api.com"
66+
assert headers["Host"] == expected_host
3867
assert headers["User-Agent"] == "test-agent"
3968
assert headers["Accept"] == "application/json"
69+
assert headers["Via"] == expected_via
4070

4171
# Check modern forwarded header
4272
assert "Forwarded" in headers
@@ -46,60 +76,28 @@ async def test_basic_headers(mock_request, reverse_proxy_handler):
4676
assert "proto=http" in forwarded
4777
assert "path=/" in forwarded
4878

49-
# Check Via header
50-
assert headers["Via"] == "1.1 stac-auth-proxy"
51-
52-
# Legacy headers should not be present by default
53-
assert "X-Forwarded-For" not in headers
54-
assert "X-Forwarded-Host" not in headers
55-
assert "X-Forwarded-Proto" not in headers
56-
assert "X-Forwarded-Path" not in headers
57-
58-
59-
@pytest.mark.asyncio
60-
async def test_legacy_forwarded_headers(mock_request):
61-
"""Test that legacy X-Forwarded-* headers are set when enabled."""
62-
handler = ReverseProxyHandler(
63-
upstream="http://upstream-api.com", legacy_forwarded_headers=True
64-
)
65-
headers = handler._prepare_headers(mock_request)
66-
67-
# Check legacy headers
68-
assert headers["X-Forwarded-For"] == "unknown"
69-
assert headers["X-Forwarded-Host"] == "localhost:8000"
70-
assert headers["X-Forwarded-Proto"] == "http"
71-
assert headers["X-Forwarded-Path"] == "/"
72-
73-
# Modern Forwarded header should still be present
74-
assert "Forwarded" in headers
75-
76-
77-
@pytest.mark.asyncio
78-
async def test_override_host_disabled(mock_request):
79-
"""Test that host override can be disabled."""
80-
handler = ReverseProxyHandler(
81-
upstream="http://upstream-api.com", override_host=False
82-
)
83-
headers = handler._prepare_headers(mock_request)
84-
assert headers["Host"] == "localhost:8000"
85-
86-
87-
@pytest.mark.asyncio
88-
async def test_custom_proxy_name(mock_request):
89-
"""Test that custom proxy name is used in Via header."""
90-
handler = ReverseProxyHandler(
91-
upstream="http://upstream-api.com", proxy_name="custom-proxy"
92-
)
93-
headers = handler._prepare_headers(mock_request)
94-
assert headers["Via"] == "1.1 custom-proxy"
79+
# Check legacy headers based on configuration
80+
if legacy_headers:
81+
assert headers["X-Forwarded-For"] == "unknown"
82+
assert headers["X-Forwarded-Host"] == "localhost:8000"
83+
assert headers["X-Forwarded-Proto"] == "http"
84+
assert headers["X-Forwarded-Path"] == "/"
85+
else:
86+
assert "X-Forwarded-For" not in headers
87+
assert "X-Forwarded-Host" not in headers
88+
assert "X-Forwarded-Proto" not in headers
89+
assert "X-Forwarded-Path" not in headers
9590

9691

92+
@pytest.mark.parametrize("legacy_headers", [False, True])
9793
@pytest.mark.asyncio
98-
async def test_forwarded_headers_with_client(mock_request):
94+
async def test_forwarded_headers_with_client(mock_request, legacy_headers):
9995
"""Test forwarded headers when client information is available."""
10096
# Add client information to the request
10197
mock_request.scope["client"] = ("192.168.1.1", 12345)
102-
handler = ReverseProxyHandler(upstream="http://upstream-api.com")
98+
handler = ReverseProxyHandler(
99+
upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
100+
)
103101
headers = handler._prepare_headers(mock_request)
104102

105103
# Check modern Forwarded header
@@ -109,56 +107,37 @@ async def test_forwarded_headers_with_client(mock_request):
109107
assert "proto=http" in forwarded
110108
assert "path=/" in forwarded
111109

112-
# Legacy headers should not be present by default
113-
assert "X-Forwarded-For" not in headers
114-
assert "X-Forwarded-Host" not in headers
115-
assert "X-Forwarded-Proto" not in headers
116-
assert "X-Forwarded-Path" not in headers
110+
# Check legacy headers based on configuration
111+
if legacy_headers:
112+
assert headers["X-Forwarded-For"] == "192.168.1.1"
113+
assert headers["X-Forwarded-Host"] == "localhost:8000"
114+
assert headers["X-Forwarded-Proto"] == "http"
115+
assert headers["X-Forwarded-Path"] == "/"
116+
else:
117+
assert "X-Forwarded-For" not in headers
118+
assert "X-Forwarded-Host" not in headers
119+
assert "X-Forwarded-Proto" not in headers
120+
assert "X-Forwarded-Path" not in headers
117121

118122

123+
@pytest.mark.parametrize("legacy_headers", [False, True])
119124
@pytest.mark.asyncio
120-
async def test_legacy_forwarded_headers_with_client(mock_request):
121-
"""Test legacy forwarded headers when client information is available."""
122-
mock_request.scope["client"] = ("192.168.1.1", 12345)
125+
async def test_https_proto(mock_request, legacy_headers):
126+
"""Test that protocol is set correctly for HTTPS."""
127+
mock_request.scope["scheme"] = "https"
123128
handler = ReverseProxyHandler(
124-
upstream="http://upstream-api.com", legacy_forwarded_headers=True
129+
upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
125130
)
126131
headers = handler._prepare_headers(mock_request)
127132

128-
# Check legacy headers
129-
assert headers["X-Forwarded-For"] == "192.168.1.1"
130-
assert headers["X-Forwarded-Host"] == "localhost:8000"
131-
assert headers["X-Forwarded-Proto"] == "http"
132-
assert headers["X-Forwarded-Path"] == "/"
133-
134-
# Modern Forwarded header should still be present
135-
assert "Forwarded" in headers
136-
137-
138-
@pytest.mark.asyncio
139-
async def test_https_proto(mock_request):
140-
"""Test that X-Forwarded-Proto is set correctly for HTTPS."""
141-
mock_request.scope["scheme"] = "https"
142-
handler = ReverseProxyHandler(upstream="http://upstream-api.com")
143-
headers = handler._prepare_headers(mock_request)
144-
145133
# Check modern Forwarded header
146134
assert "proto=https" in headers["Forwarded"]
147135

148-
# Legacy headers should not be present by default
149-
assert "X-Forwarded-Proto" not in headers
150-
151-
152-
@pytest.mark.asyncio
153-
async def test_https_proto_legacy(mock_request):
154-
"""Test that X-Forwarded-Proto is set correctly for HTTPS with legacy headers."""
155-
mock_request.scope["scheme"] = "https"
156-
handler = ReverseProxyHandler(
157-
upstream="http://upstream-api.com", legacy_forwarded_headers=True
158-
)
159-
headers = handler._prepare_headers(mock_request)
160-
assert headers["X-Forwarded-Proto"] == "https"
161-
assert "proto=https" in headers["Forwarded"]
136+
# Check legacy headers based on configuration
137+
if legacy_headers:
138+
assert headers["X-Forwarded-Proto"] == "https"
139+
else:
140+
assert "X-Forwarded-Proto" not in headers
162141

163142

164143
@pytest.mark.asyncio
@@ -171,3 +150,135 @@ async def test_non_standard_port(mock_request):
171150
handler = ReverseProxyHandler(upstream="http://upstream-api.com:8080")
172151
headers = handler._prepare_headers(mock_request)
173152
assert headers["Host"] == "upstream-api.com:8080"
153+
154+
155+
@pytest.mark.parametrize("legacy_headers", [False, True])
156+
@pytest.mark.asyncio
157+
async def test_nginx_proxy_headers_preserved(legacy_headers):
158+
"""Test that existing proxy headers from NGINX are preserved."""
159+
# Simulate a request that already has proxy headers set by NGINX
160+
headers = [
161+
(b"host", b"localhost:8000"),
162+
(b"user-agent", b"test-agent"),
163+
(b"x-forwarded-for", b"203.0.113.1, 198.51.100.1"),
164+
(b"x-forwarded-proto", b"https"),
165+
(b"x-forwarded-host", b"api.example.com"),
166+
(b"x-forwarded-path", b"/api/v1"),
167+
]
168+
request = create_request(headers=headers)
169+
handler = ReverseProxyHandler(
170+
upstream="http://upstream-api.com", legacy_forwarded_headers=legacy_headers
171+
)
172+
headers = handler._prepare_headers(request)
173+
174+
# Check that the existing proxy headers are preserved in the Forwarded header
175+
forwarded = headers["Forwarded"]
176+
assert "for=203.0.113.1, 198.51.100.1" in forwarded
177+
assert "host=api.example.com" in forwarded
178+
assert "proto=https" in forwarded
179+
assert "path=/api/v1" in forwarded
180+
181+
# The original headers should still be present (they're preserved from the request)
182+
assert headers["X-Forwarded-For"] == "203.0.113.1, 198.51.100.1"
183+
assert headers["X-Forwarded-Host"] == "api.example.com"
184+
assert headers["X-Forwarded-Proto"] == "https"
185+
assert headers["X-Forwarded-Path"] == "/api/v1"
186+
187+
188+
@pytest.mark.parametrize(
189+
"scope_overrides,headers,expected_forwarded",
190+
[
191+
pytest.param(
192+
{},
193+
[
194+
(b"host", b"localhost:8000"),
195+
(b"user-agent", b"test-agent"),
196+
(b"x-forwarded-for", b"203.0.113.1"),
197+
(b"x-forwarded-proto", b"https"),
198+
# Missing X-Forwarded-Host and X-Forwarded-Path
199+
],
200+
{
201+
"for": "203.0.113.1", # From existing header
202+
"host": "localhost:8000", # Fallback to request host
203+
"proto": "https", # From existing header
204+
"path": "/", # Fallback to request path
205+
},
206+
id="partial_headers_fallback",
207+
),
208+
pytest.param(
209+
{"client": ("192.168.1.1", 12345)}, # This should be ignored
210+
[
211+
(b"host", b"localhost:8000"),
212+
(b"user-agent", b"test-agent"),
213+
(b"x-forwarded-for", b"203.0.113.1, 198.51.100.1"),
214+
],
215+
{
216+
"for": "203.0.113.1, 198.51.100.1", # From existing header
217+
"host": "localhost:8000",
218+
"proto": "http",
219+
"path": "/",
220+
},
221+
id="client_info_precedence",
222+
),
223+
pytest.param(
224+
{"scheme": "https"}, # This should be ignored
225+
[
226+
(b"host", b"localhost:8000"),
227+
(b"user-agent", b"test-agent"),
228+
(b"x-forwarded-proto", b"http"), # NGINX says it's HTTP
229+
],
230+
{
231+
"for": "unknown",
232+
"host": "localhost:8000",
233+
"proto": "http", # From existing header
234+
"path": "/",
235+
},
236+
id="scheme_precedence",
237+
),
238+
pytest.param(
239+
{"path": "/custom/path"},
240+
[
241+
(b"host", b"localhost:8000"),
242+
(b"user-agent", b"test-agent"),
243+
(b"x-forwarded-path", b"/api/v1/root"), # NGINX says different path
244+
],
245+
{
246+
"for": "unknown",
247+
"host": "localhost:8000",
248+
"proto": "http",
249+
"path": "/api/v1/root", # From existing header
250+
},
251+
id="path_precedence",
252+
),
253+
pytest.param(
254+
{},
255+
[
256+
(b"host", b"localhost:8000"),
257+
(b"user-agent", b"test-agent"),
258+
(b"X-Forwarded-For", b"203.0.113.1"), # Mixed case
259+
(b"x-forwarded-proto", b"https"), # Lower case
260+
(b"X-FORWARDED-HOST", b"api.example.com"), # Upper case
261+
],
262+
{
263+
"for": "203.0.113.1",
264+
"host": "api.example.com",
265+
"proto": "https",
266+
"path": "/",
267+
},
268+
id="case_insensitive",
269+
),
270+
],
271+
)
272+
@pytest.mark.asyncio
273+
async def test_nginx_headers_behavior(scope_overrides, headers, expected_forwarded):
274+
"""Test various NGINX header behaviors and precedence rules."""
275+
request = create_request(scope_overrides=scope_overrides, headers=headers)
276+
handler = ReverseProxyHandler(upstream="http://upstream-api.com")
277+
result_headers = handler._prepare_headers(request)
278+
279+
# Check that the Forwarded header contains expected values
280+
forwarded = result_headers["Forwarded"]
281+
for key, expected_value in expected_forwarded.items():
282+
assert (
283+
f"{key}={expected_value}" in forwarded
284+
), f"Expected {key}={expected_value} in {forwarded}"

0 commit comments

Comments
 (0)