Skip to content

Commit 2fe0852

Browse files
authored
fix: handle compressed OpenAPI responses & ensure paths correctly marked private (#29)
1 parent 00fe51b commit 2fe0852

File tree

6 files changed

+261
-36
lines changed

6 files changed

+261
-36
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,9 @@ dev = [
4646
"pytest-asyncio>=0.25.1",
4747
"pytest-cov>=5.0.0",
4848
"pytest>=8.3.3",
49+
"starlette-cramjam>=0.4.0",
4950
]
5051

5152
[tool.pytest.ini_options]
5253
asyncio_default_fixture_loop_scope = "function"
53-
asyncio_mode = "auto"
54+
asyncio_mode = "auto"

src/stac_auth_proxy/app.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def create_app(settings: Optional[Settings] = None) -> FastAPI:
3838
OpenApiMiddleware,
3939
openapi_spec_path=settings.openapi_spec_endpoint,
4040
oidc_config_url=str(settings.oidc_discovery_url),
41+
public_endpoints=settings.public_endpoints,
4142
private_endpoints=settings.private_endpoints,
4243
default_public=settings.default_public,
4344
)
Lines changed: 81 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,26 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3+
import gzip
34
import json
5+
import re
6+
import zlib
47
from dataclasses import dataclass
5-
from typing import Any
8+
from typing import Any, Optional
69

10+
import brotli
11+
from starlette.datastructures import MutableHeaders
712
from starlette.requests import Request
813
from starlette.types import ASGIApp, Message, Receive, Scope, Send
914

1015
from ..config import EndpointMethods
1116
from ..utils.requests import dict_to_bytes
1217

18+
ENCODING_HANDLERS = {
19+
"gzip": gzip,
20+
"deflate": zlib,
21+
"br": brotli,
22+
}
23+
1324

1425
@dataclass(frozen=True)
1526
class OpenApiMiddleware:
@@ -19,6 +30,7 @@ class OpenApiMiddleware:
1930
openapi_spec_path: str
2031
oidc_config_url: str
2132
private_endpoints: EndpointMethods
33+
public_endpoints: EndpointMethods
2234
default_public: bool
2335
oidc_auth_scheme_name: str = "oidcAuth"
2436

@@ -27,26 +39,63 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2739
if scope["type"] != "http" or Request(scope).url.path != self.openapi_spec_path:
2840
return await self.app(scope, receive, send)
2941

30-
total_body = b""
42+
start_message: Optional[Message] = None
43+
body = b""
3144

3245
async def augment_oidc_spec(message: Message):
33-
if message["type"] != "http.response.body":
46+
nonlocal start_message
47+
nonlocal body
48+
if message["type"] == "http.response.start":
49+
# NOTE: Because we are modifying the response body, we will need to update
50+
# the content-length header. However, headers are sent before we see the
51+
# body. To handle this, we delay sending the http.response.start message
52+
# until after we alter the body.
53+
start_message = message
54+
return
55+
elif message["type"] != "http.response.body":
3456
return await send(message)
3557

36-
# TODO: Make more robust to handle non-JSON responses
37-
38-
nonlocal total_body
39-
40-
total_body += message["body"]
58+
body += message["body"]
4159

42-
# Pass empty body chunks until all chunks have been received
60+
# Skip body chunks until all chunks have been received
4361
if message["more_body"]:
44-
return await send({**message, "body": b""})
45-
62+
return
63+
64+
# Maybe decompress the body
65+
headers = MutableHeaders(scope=start_message)
66+
content_encoding = headers.get("content-encoding", "").lower()
67+
handler = None
68+
if content_encoding:
69+
handler = ENCODING_HANDLERS.get(content_encoding)
70+
assert handler, f"Unsupported content encoding: {content_encoding}"
71+
body = (
72+
handler.decompress(body)
73+
if content_encoding != "deflate"
74+
else handler.decompress(body, -zlib.MAX_WBITS)
75+
)
76+
77+
# Augment the spec
78+
body = dict_to_bytes(self.augment_spec(json.loads(body)))
79+
80+
# Maybe re-compress the body
81+
if handler:
82+
body = handler.compress(body)
83+
84+
# Update the content-length header
85+
headers["content-length"] = str(len(body))
86+
assert start_message, "Expected start_message to be set"
87+
start_message["headers"] = [
88+
(key.encode(), value.encode()) for key, value in headers.items()
89+
]
90+
91+
# Send http.response.start
92+
await send(start_message)
93+
94+
# Send http.response.body
4695
await send(
4796
{
4897
"type": "http.response.body",
49-
"body": dict_to_bytes(self.augment_spec(json.loads(total_body))),
98+
"body": body,
5099
"more_body": False,
51100
}
52101
)
@@ -63,9 +112,24 @@ def augment_spec(self, openapi_spec) -> dict[str, Any]:
63112
}
64113
for path, method_config in openapi_spec["paths"].items():
65114
for method, config in method_config.items():
66-
for private_method in self.private_endpoints.get(path, []):
67-
if method.casefold() == private_method.casefold():
68-
config.setdefault("security", []).append(
69-
{self.oidc_auth_scheme_name: []}
70-
)
115+
requires_auth = (
116+
self.path_matches(path, method, self.private_endpoints)
117+
if self.default_public
118+
else not self.path_matches(path, method, self.public_endpoints)
119+
)
120+
if requires_auth:
121+
config.setdefault("security", []).append(
122+
{self.oidc_auth_scheme_name: []}
123+
)
71124
return openapi_spec
125+
126+
@staticmethod
127+
def path_matches(path: str, method: str, endpoints: EndpointMethods) -> bool:
128+
"""Check if the given path and method match any of the regex patterns and methods in the endpoints."""
129+
for pattern, endpoint_methods in endpoints.items():
130+
if not re.match(pattern, path):
131+
continue
132+
for endpoint_method in endpoint_methods:
133+
if method.casefold() == endpoint_method.casefold():
134+
return True
135+
return False

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import uvicorn
1111
from fastapi import FastAPI
1212
from jwcrypto import jwk, jwt
13+
from starlette_cramjam.middleware import CompressionMiddleware
1314
from utils import single_chunk_async_stream_response
1415

1516

@@ -69,6 +70,8 @@ def source_api():
6970
"""Create upstream API for testing purposes."""
7071
app = FastAPI(docs_url="/api.html", openapi_url="/api")
7172

73+
app.add_middleware(CompressionMiddleware, minimum_size=0, compression_level=1)
74+
7275
for path, methods in {
7376
"/": [
7477
"GET",
@@ -118,7 +121,7 @@ def source_api():
118121
return app
119122

120123

121-
@pytest.fixture(scope="session")
124+
@pytest.fixture
122125
def source_api_server(source_api):
123126
"""Run the source API in a background thread."""
124127
host, port = "127.0.0.1", 9119

tests/test_openapi.py

Lines changed: 75 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Tests for OpenAPI spec handling."""
22

3+
import pytest
34
from fastapi import FastAPI
45
from fastapi.testclient import TestClient
56
from utils import AppFactory
@@ -40,7 +41,6 @@ def test_no_private_endpoints(source_api_server: str):
4041
assert "info" in openapi
4142
assert "openapi" in openapi
4243
assert "paths" in openapi
43-
# assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})
4444

4545

4646
def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
@@ -59,37 +59,95 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
5959
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
6060

6161

62+
@pytest.mark.parametrize("compression_type", ["gzip", "br", "deflate"])
63+
def test_oidc_in_openapi_spec_compressed(
64+
source_api: FastAPI, source_api_server: str, compression_type: str
65+
):
66+
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
67+
app = app_factory(
68+
upstream_url=source_api_server,
69+
openapi_spec_endpoint=source_api.openapi_url,
70+
)
71+
client = TestClient(app)
72+
73+
# Test with gzip acceptance
74+
response = client.get(
75+
source_api.openapi_url, headers={"Accept-Encoding": compression_type}
76+
)
77+
assert response.status_code == 200
78+
assert response.headers.get("content-encoding") == compression_type
79+
assert response.headers.get("content-type") == "application/json"
80+
assert response.json()
81+
82+
6283
def test_oidc_in_openapi_spec_private_endpoints(
6384
source_api: FastAPI, source_api_server: str
6485
):
6586
"""When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec."""
6687
private_endpoints = {
6788
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
89+
r"^/collections$": ["POST"],
90+
r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
91+
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
92+
r"^/collections/([^/]+)/items$": ["POST"],
93+
r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
94+
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
95+
r"^/collections/([^/]+)/bulk_items$": ["POST"],
96+
}
97+
app = app_factory(
98+
upstream_url=source_api_server,
99+
openapi_spec_endpoint=source_api.openapi_url,
100+
default_public=True,
101+
private_endpoints=private_endpoints,
102+
)
103+
client = TestClient(app)
104+
105+
openapi = client.get(source_api.openapi_url).raise_for_status().json()
106+
107+
expected_auth = {
68108
"/collections": ["POST"],
69109
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
70-
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
71110
"/collections/{collection_id}/items": ["POST"],
72111
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
73-
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
74112
"/collections/{collection_id}/bulk_items": ["POST"],
75113
}
114+
for path, method_config in openapi["paths"].items():
115+
for method, config in method_config.items():
116+
security = config.get("security")
117+
path_in_expected_auth = path in expected_auth
118+
method_in_expected_auth = any(
119+
method.casefold() == m.casefold() for m in expected_auth.get(path, [])
120+
)
121+
if security:
122+
assert path_in_expected_auth
123+
assert method_in_expected_auth
124+
else:
125+
assert not all([path_in_expected_auth, method_in_expected_auth])
126+
127+
128+
def test_oidc_in_openapi_spec_public_endpoints(
129+
source_api: FastAPI, source_api_server: str
130+
):
131+
"""When OpenAPI spec endpoint is set & endpoints are marked public, those endpoints are not marked private in the spec."""
132+
public = {r"^/queryables$": ["GET"], r"^/api": ["GET"]}
76133
app = app_factory(
77134
upstream_url=source_api_server,
78135
openapi_spec_endpoint=source_api.openapi_url,
79-
private_endpoints=private_endpoints,
136+
default_public=False,
137+
public_endpoints=public,
80138
)
81139
client = TestClient(app)
140+
82141
openapi = client.get(source_api.openapi_url).raise_for_status().json()
83-
for path, methods in private_endpoints.items():
84-
for method in methods:
85-
openapi_path = openapi["paths"].get(path)
86-
assert openapi_path, f"Path {path} not found in OpenAPI spec"
87-
openapi_path_method = openapi_path.get(method.lower())
88-
assert (
89-
openapi_path_method
90-
), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}"
91-
security = openapi_path_method.get("security")
92-
assert security, f"Security not found for {path!r} {method.lower()!r}"
93-
assert any(
94-
"oidcAuth" in s for s in security
95-
), f'No "oidcAuth" in security for {path!r} {method.lower()!r}'
142+
143+
expected_auth = {"/queryables": ["GET"]}
144+
for path, method_config in openapi["paths"].items():
145+
for method, config in method_config.items():
146+
security = config.get("security")
147+
if security:
148+
assert path not in expected_auth
149+
else:
150+
assert path in expected_auth
151+
assert any(
152+
method.casefold() == m.casefold() for m in expected_auth[path]
153+
)

0 commit comments

Comments
 (0)