Skip to content

Commit 9231fa3

Browse files
committed
In progress
1 parent 00fe51b commit 9231fa3

File tree

3 files changed

+85
-35
lines changed

3 files changed

+85
-35
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
"""Middleware to add auth information to the OpenAPI spec served by upstream API."""
22

3+
import brotli
4+
import gzip
35
import json
6+
import zlib
47
from dataclasses import dataclass
58
from typing import Any
69

@@ -43,6 +46,21 @@ async def augment_oidc_spec(message: Message):
4346
if message["more_body"]:
4447
return await send({**message, "body": b""})
4548

49+
# Handle compressed responses
50+
# content_encoding = (
51+
# message.get("headers", {})
52+
# .get(b"content-encoding", b"")
53+
# .decode()
54+
# .lower()
55+
# )
56+
# if content_encoding:
57+
# if "gzip" in content_encoding:
58+
# total_body = gzip.decompress(total_body)
59+
# elif "deflate" in content_encoding:
60+
# total_body = zlib.decompress(total_body)
61+
# elif "br" in content_encoding:
62+
# total_body = brotli.decompress(total_body)
63+
# print(f"{message=}")
4664
await send(
4765
{
4866
"type": "http.response.body",

tests/conftest.py

Lines changed: 26 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import pytest
1010
import uvicorn
11-
from fastapi import FastAPI
11+
from fastapi import FastAPI, Request
1212
from jwcrypto import jwk, jwt
1313
from utils import single_chunk_async_stream_response
1414

@@ -64,61 +64,52 @@ def build_token(payload: dict[str, Any], key=None) -> str:
6464
return build_token
6565

6666

67-
@pytest.fixture(scope="session")
67+
@pytest.fixture
6868
def source_api():
6969
"""Create upstream API for testing purposes."""
7070
app = FastAPI(docs_url="/api.html", openapi_url="/api")
7171

72+
def default_response_factory(request: Request):
73+
"""Default response factory."""
74+
return {"id": f"Response from {request.method}@{request.url.path}"}
75+
76+
# Store response factory in app state
77+
app.state.response_factory = default_response_factory
78+
7279
for path, methods in {
73-
"/": [
74-
"GET",
75-
],
76-
"/conformance": [
77-
"GET",
78-
],
79-
"/queryables": [
80-
"GET",
81-
],
82-
"/search": [
83-
"GET",
84-
"POST",
85-
],
86-
"/collections": [
87-
"GET",
88-
"POST",
89-
],
90-
"/collections/{collection_id}": [
91-
"GET",
92-
"PUT",
93-
"PATCH",
94-
"DELETE",
95-
],
96-
"/collections/{collection_id}/items": [
97-
"GET",
98-
"POST",
99-
],
80+
"/": ["GET"],
81+
"/conformance": ["GET"],
82+
"/queryables": ["GET"],
83+
"/search": ["GET", "POST"],
84+
"/collections": ["GET", "POST"],
85+
"/collections/{collection_id}": ["GET", "PUT", "PATCH", "DELETE"],
86+
"/collections/{collection_id}/items": ["GET", "POST"],
10087
"/collections/{collection_id}/items/{item_id}": [
10188
"GET",
10289
"PUT",
10390
"PATCH",
10491
"DELETE",
10592
],
106-
"/collections/{collection_id}/bulk_items": [
107-
"POST",
108-
],
93+
"/collections/{collection_id}/bulk_items": ["POST"],
10994
}.items():
11095
for method in methods:
111-
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
96+
97+
def endpoint(request: Request):
98+
print(f"endpoint: {request}")
99+
return app.state.response_factory(request)
100+
101+
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID"
112102
app.add_api_route(
113103
path,
114-
lambda: {"id": f"Response from {method}@{path}"},
104+
endpoint,
105+
# lambda: {"id": f"Response from {method}@{path}"},
115106
methods=[method],
116107
)
117108

118109
return app
119110

120111

121-
@pytest.fixture(scope="session")
112+
@pytest.fixture
122113
def source_api_server(source_api):
123114
"""Run the source API in a background thread."""
124115
host, port = "127.0.0.1", 9119

tests/test_openapi.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from fastapi import FastAPI
44
from fastapi.testclient import TestClient
55
from utils import AppFactory
6+
import gzip
7+
from fastapi import Request, Response
68

79
app_factory = AppFactory(
810
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration"
@@ -59,6 +61,45 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
5961
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
6062

6163

64+
def test_oidc_in_openapi_spec_compressed(source_api: FastAPI, source_api_server: str):
65+
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
66+
67+
# Create a compressed response factory
68+
def compressed_response_factory(request: Request):
69+
assert False
70+
# Get the original OpenAPI spec
71+
openapi = source_api.openapi()
72+
compressed_data = gzip.compress(str(openapi).encode())
73+
return Response(
74+
content=compressed_data,
75+
headers={
76+
"Content-Encoding": "gzip",
77+
"Content-Type": "application/json",
78+
},
79+
)
80+
81+
# Set the custom response factory
82+
source_api.state.response_factory = compressed_response_factory
83+
84+
app = app_factory(
85+
upstream_url=source_api_server,
86+
openapi_spec_endpoint=source_api.openapi_url,
87+
)
88+
client = TestClient(app)
89+
response = client.get(
90+
source_api.openapi_url,
91+
headers={"Accept-Encoding": "gzip"},
92+
)
93+
assert response.status_code == 200
94+
assert response.headers.get("content-encoding") == "gzip"
95+
96+
openapi = response.json() # TestClient automatically decompresses
97+
assert "info" in openapi
98+
assert "openapi" in openapi
99+
assert "paths" in openapi
100+
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
101+
102+
62103
def test_oidc_in_openapi_spec_private_endpoints(
63104
source_api: FastAPI, source_api_server: str
64105
):

0 commit comments

Comments
 (0)