Skip to content

Commit f45890c

Browse files
committed
Fix
1 parent 585c9f1 commit f45890c

File tree

2 files changed

+44
-40
lines changed

2 files changed

+44
-40
lines changed

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ async def augment_oidc_spec(message: Message):
6868
if content_encoding:
6969
handler = ENCODING_HANDLERS.get(content_encoding)
7070
assert handler, f"Unsupported content encoding: {content_encoding}"
71-
body = handler.decompress(body)
71+
body = (
72+
handler.decompress(body)
73+
if content_encoding != "deflate"
74+
else handler.decompress(body, -zlib.MAX_WBITS)
75+
)
7276

7377
# Augment the spec
7478
body = dict_to_bytes(self.augment_spec(json.loads(body)))

tests/test_openapi.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Tests for OpenAPI spec handling."""
22

3-
3+
import pytest
44
from fastapi import FastAPI
55
from fastapi.testclient import TestClient
66
from utils import AppFactory
@@ -41,7 +41,6 @@ def test_no_private_endpoints(source_api_server: str):
4141
assert "info" in openapi
4242
assert "openapi" in openapi
4343
assert "paths" in openapi
44-
# assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})
4544

4645

4746
def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
@@ -60,43 +59,44 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
6059
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
6160

6261

63-
# def test_oidc_in_openapi_spec_compressed(source_api: FastAPI, source_api_server: str):
64-
# """When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
65-
66-
# # Create a compressed response factory
67-
# def compressed_response_factory(request: Request):
68-
# assert False
69-
# # Get the original OpenAPI spec
70-
# openapi = source_api.openapi()
71-
# compressed_data = gzip.compress(str(openapi).encode())
72-
# return Response(
73-
# content=compressed_data,
74-
# headers={
75-
# "Content-Encoding": "gzip",
76-
# "Content-Type": "application/json",
77-
# },
78-
# )
79-
80-
# # Set the custom response factory
81-
# source_api.state.response_factory = compressed_response_factory
82-
83-
# app = app_factory(
84-
# upstream_url=source_api_server,
85-
# openapi_spec_endpoint=source_api.openapi_url,
86-
# )
87-
# client = TestClient(app)
88-
# response = client.get(
89-
# source_api.openapi_url,
90-
# headers={"Accept-Encoding": "gzip"},
91-
# )
92-
# assert response.status_code == 200
93-
# assert response.headers.get("content-encoding") == "gzip"
94-
95-
# openapi = response.json() # TestClient automatically decompresses
96-
# assert "info" in openapi
97-
# assert "openapi" in openapi
98-
# assert "paths" in openapi
99-
# assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
62+
@pytest.mark.parametrize("compression_type", ["gzip", "br", "deflate"])
63+
def test_oidc_in_openapi_spec_compressed(
64+
source_api: FastAPI, source_api_server: str, compression_type: str
65+
):
66+
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
67+
app = app_factory(
68+
upstream_url=source_api_server,
69+
openapi_spec_endpoint=source_api.openapi_url,
70+
)
71+
client = TestClient(app)
72+
73+
# Test with gzip acceptance
74+
response = client.get(
75+
source_api.openapi_url, headers={"Accept-Encoding": compression_type}
76+
)
77+
assert response.status_code == 200
78+
assert response.headers.get("content-encoding") == compression_type
79+
assert response.headers.get("content-type") == "application/json"
80+
81+
# TestClient automatically decompresses
82+
openapi = response.json()
83+
assert "info" in openapi
84+
assert "openapi" in openapi
85+
assert "paths" in openapi
86+
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
87+
88+
# # Test without gzip acceptance
89+
# response = client.get(source_api.openapi_url)
90+
# assert response.status_code == 200
91+
# assert "content-encoding" not in response.headers
92+
# assert response.headers.get("content-type") == "application/json"
93+
94+
# # Should get same content
95+
# openapi = response.json()
96+
# assert "info" in openapi
97+
# assert "openapi" in openapi
98+
# assert "paths" in openapi
99+
# assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
100100

101101

102102
def test_oidc_in_openapi_spec_private_endpoints(

0 commit comments

Comments
 (0)