Skip to content

Commit 1189f97

Browse files
committed
fix: prevent sending automatic accept-encoding headers upstream
1 parent faffd05 commit 1189f97

File tree

5 files changed

+79
-97
lines changed

5 files changed

+79
-97
lines changed

src/stac_auth_proxy/handlers/reverse_proxy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ async def proxy_request(self, request: Request) -> httpx.Response:
4444
headers=headers,
4545
content=request.stream(),
4646
)
47+
48+
# NOTE: HTTPX adds headers, so we need to trim them before sending request
49+
for h in rp_req.headers:
50+
if h not in headers:
51+
del rp_req.headers[h]
52+
4753
logger.debug(f"Proxying request to {rp_req.url}")
4854

4955
start_time = time.perf_counter()

tests/test_authn.py

Lines changed: 0 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -112,79 +112,3 @@ def test_scopes(
112112
)
113113
expected_status_code = 200 if expected_permitted else 401
114114
assert response.status_code == expected_status_code
115-
116-
117-
# @pytest.mark.parametrize(
118-
# "is_valid, path, method",
119-
# [
120-
# *[
121-
# [True, *endpoint_method]
122-
# for endpoint_method in [
123-
# ["/collections", "POST"],
124-
# ["/collections/foo", "PUT"],
125-
# ["/collections/foo", "PATCH"],
126-
# ["/collections/foo/items", "POST"],
127-
# ["/collections/foo/items/bar", "PUT"],
128-
# ["/collections/foo/items/bar", "PATCH"],
129-
# ]
130-
# ],
131-
# *[
132-
# [False, *endpoint_method]
133-
# for endpoint_method in [
134-
# ["/collections/foo", "DELETE"],
135-
# ["/collections/foo/items/bar", "DELETE"],
136-
# ]
137-
# ],
138-
# ],
139-
# )
140-
# def test_scopes(source_api_server, token_builder, is_valid, path, method):
141-
# """Private endpoints permit access with a valid token."""
142-
# test_app = app_factory(
143-
# upstream_url=source_api_server,
144-
# default_public=True,
145-
# private_endpoints={
146-
# r"^/collections$": [
147-
# ("POST", ["collections:create"]),
148-
# ],
149-
# r"^/collections/([^/]+)$": [
150-
# # ("PUT", ["collections:update"]),
151-
# # ("PATCH", ["collections:update"]),
152-
# ("DELETE", ["collections:delete"]),
153-
# ],
154-
# r"^/collections/([^/]+)/items$": [
155-
# ("POST", ["items:create"]),
156-
# ],
157-
# r"^/collections/([^/]+)/items/([^/]+)$": [
158-
# # ("PUT", ["items:update"]),
159-
# # ("PATCH", ["items:update"]),
160-
# ("DELETE", ["items:delete"]),
161-
# ],
162-
# r"^/collections/([^/]+)/bulk_items$": [
163-
# ("POST", ["items:create"]),
164-
# ],
165-
# },
166-
# )
167-
# valid_auth_token = token_builder(
168-
# {
169-
# "scopes": " ".join(
170-
# [
171-
# "collection:create",
172-
# "items:create",
173-
# "collections:update",
174-
# "items:update",
175-
# ]
176-
# )
177-
# }
178-
# )
179-
# client = TestClient(test_app)
180-
181-
# response = client.request(
182-
# method=method,
183-
# url=path,
184-
# headers={"Authorization": f"Bearer {valid_auth_token}"},
185-
# json={} if method != "DELETE" else None,
186-
# )
187-
# if is_valid:
188-
# assert response.status_code == 200
189-
# else:
190-
# assert response.status_code == 403

tests/test_filters_jinja2.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
"""Tests for Jinja2 CQL2 filter (simplified for readability)."""
22

33
import json
4-
from typing import cast
5-
from unittest.mock import MagicMock
64

75
import cql2
86
import pytest
97
from fastapi.testclient import TestClient
10-
from httpx import Request
11-
from utils import AppFactory, parse_query_string
8+
from utils import AppFactory, get_upstream_request
129

1310
FILTER_EXPR_CASES = [
1411
pytest.param(
@@ -148,14 +145,6 @@ def _build_client(
148145
return TestClient(app, headers=headers)
149146

150147

151-
async def _get_upstream_request(mock_upstream: MagicMock):
152-
"""Fetch the raw body and query params from the single upstream request."""
153-
assert mock_upstream.call_count == 1
154-
[request] = cast(list[Request], mock_upstream.call_args[0])
155-
req_body = request._streamed_body
156-
return req_body.decode(), parse_query_string(request.url.query.decode("utf-8"))
157-
158-
159148
@pytest.mark.parametrize(
160149
"filter_template_expr, expected_auth_filter, expected_anon_filter",
161150
FILTER_EXPR_CASES,
@@ -182,8 +171,8 @@ async def test_search_post(
182171
response.raise_for_status()
183172

184173
# Retrieve the JSON body that was actually sent upstream
185-
proxied_body_str = (await _get_upstream_request(mock_upstream))[0]
186-
proxied_body = json.loads(proxied_body_str)
174+
proxied_request = await get_upstream_request(mock_upstream)
175+
proxied_body = json.loads(proxied_request.body)
187176

188177
# Determine the expected combined filter
189178
proxy_filter = cql2.Expr(
@@ -231,8 +220,8 @@ async def test_search_get(
231220
response.raise_for_status()
232221

233222
# For GET, we expect the upstream body to be empty, but URL params to be appended
234-
proxied_body, upstream_query = await _get_upstream_request(mock_upstream)
235-
assert proxied_body == ""
223+
proxied_request = await get_upstream_request(mock_upstream)
224+
assert proxied_request.body == ""
236225

237226
# Determine the expected combined filter
238227
proxy_filter = cql2.Expr(
@@ -253,7 +242,7 @@ async def test_search_get(
253242
"filter-lang": filter_lang,
254243
}
255244
assert (
256-
upstream_query == expected_output
245+
proxied_request.query_params == expected_output
257246
), "GET query should combine filter expressions."
258247

259248

@@ -284,15 +273,15 @@ async def test_items_list(
284273
response.raise_for_status()
285274

286275
# For GET items, we also expect an empty body and appended querystring
287-
proxied_body, proxied_query = await _get_upstream_request(mock_upstream)
288-
assert proxied_body == ""
276+
proxied_request = await get_upstream_request(mock_upstream)
277+
assert proxied_request.body == ""
289278

290279
# Only the appended filter (no input_filter merges in these particular tests),
291280
# but you could do similar merging logic if needed.
292281
proxy_filter = cql2.Expr(
293282
expected_auth_filter if is_authenticated else expected_anon_filter
294283
)
295-
assert proxied_query == {
284+
assert proxied_request.query_params == {
296285
"filter-lang": "cql2-text",
297286
"filter": (
298287
proxy_filter + cql2.Expr(qs_filter)

tests/test_proxy.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
"""Test authentication cases for the proxy app."""
2+
3+
from fastapi.testclient import TestClient
4+
from utils import AppFactory, get_upstream_request
5+
6+
app_factory = AppFactory(
7+
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration",
8+
default_public=True,
9+
public_endpoints={},
10+
private_endpoints={},
11+
)
12+
13+
14+
async def test_proxied_headers_no_encoding(source_api_server, mock_upstream):
15+
"""Clients that don't accept encoding should not receive it."""
16+
test_app = app_factory(upstream_url=source_api_server)
17+
18+
client = TestClient(test_app)
19+
req = client.build_request(method="GET", url="/", headers={})
20+
for h in req.headers:
21+
if h in ["accept-encoding"]:
22+
del req.headers[h]
23+
client.send(req)
24+
25+
proxied_request = await get_upstream_request(mock_upstream)
26+
assert "accept-encoding" not in proxied_request.headers
27+
28+
29+
async def test_proxied_headers_with_encoding(source_api_server, mock_upstream):
30+
"""Clients that do accept encoding should receive it."""
31+
test_app = app_factory(upstream_url=source_api_server)
32+
33+
client = TestClient(test_app)
34+
req = client.build_request(
35+
method="GET", url="/", headers={"accept-encoding": "gzip"}
36+
)
37+
client.send(req)
38+
39+
proxied_request = await get_upstream_request(mock_upstream)
40+
assert proxied_request.headers.get("accept-encoding") == "gzip"

tests/utils.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
import json
44
from dataclasses import dataclass
5-
from typing import Callable
5+
from typing import Callable, cast
6+
from unittest.mock import MagicMock
67
from urllib.parse import parse_qs, unquote
78

89
import httpx
10+
from httpx import Headers, Request
911

1012
from stac_auth_proxy import Settings, create_app
1113

@@ -68,3 +70,24 @@ def parse_query_string(qs: str) -> dict:
6870
result[key] = unquote(value)
6971

7072
return result
73+
74+
75+
async def get_upstream_request(mock_upstream: MagicMock) -> "UpstreamRequest":
76+
"""Fetch the raw body and query params from the single upstream request."""
77+
assert mock_upstream.call_count == 1
78+
[request] = cast(list[Request], mock_upstream.call_args[0])
79+
req_body = request._streamed_body
80+
return UpstreamRequest(
81+
body=req_body.decode(),
82+
query_params=parse_query_string(request.url.query.decode("utf-8")),
83+
headers=request.headers,
84+
)
85+
86+
87+
@dataclass
88+
class UpstreamRequest:
89+
"""The raw body and query params from the single upstream request."""
90+
91+
body: str
92+
query_params: dict
93+
headers: Headers

0 commit comments

Comments
 (0)