Skip to content

Commit 3c45287

Browse files
committed
Make test fail for private endpoints
1 parent 0b7032b commit 3c45287

File tree

1 file changed

+94
-16
lines changed

1 file changed

+94
-16
lines changed

tests/test_openapi.py

Lines changed: 94 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
from fastapi import FastAPI
44
from fastapi.testclient import TestClient
55
from utils import AppFactory
6+
import gzip
7+
from fastapi import Request, Response
68

79
app_factory = AppFactory(
810
oidc_discovery_url="https://example-stac-api.com/.well-known/openid-configuration"
@@ -59,37 +61,113 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
5961
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
6062

6163

64+
# def test_oidc_in_openapi_spec_compressed(source_api: FastAPI, source_api_server: str):
65+
# """When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
66+
67+
# # Create a compressed response factory
68+
# def compressed_response_factory(request: Request):
69+
# assert False
70+
# # Get the original OpenAPI spec
71+
# openapi = source_api.openapi()
72+
# compressed_data = gzip.compress(str(openapi).encode())
73+
# return Response(
74+
# content=compressed_data,
75+
# headers={
76+
# "Content-Encoding": "gzip",
77+
# "Content-Type": "application/json",
78+
# },
79+
# )
80+
81+
# # Set the custom response factory
82+
# source_api.state.response_factory = compressed_response_factory
83+
84+
# app = app_factory(
85+
# upstream_url=source_api_server,
86+
# openapi_spec_endpoint=source_api.openapi_url,
87+
# )
88+
# client = TestClient(app)
89+
# response = client.get(
90+
# source_api.openapi_url,
91+
# headers={"Accept-Encoding": "gzip"},
92+
# )
93+
# assert response.status_code == 200
94+
# assert response.headers.get("content-encoding") == "gzip"
95+
96+
# openapi = response.json() # TestClient automatically decompresses
97+
# assert "info" in openapi
98+
# assert "openapi" in openapi
99+
# assert "paths" in openapi
100+
# assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
101+
102+
62103
def test_oidc_in_openapi_spec_private_endpoints(
63104
source_api: FastAPI, source_api_server: str
64105
):
65106
"""When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec."""
66107
private_endpoints = {
67108
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
109+
r"^/collections$": ["POST"],
110+
r"^/collections/([^/]+)$": ["PUT", "PATCH", "DELETE"],
111+
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
112+
r"^/collections/([^/]+)/items$": ["POST"],
113+
r"^/collections/([^/]+)/items/([^/]+)$": ["PUT", "PATCH", "DELETE"],
114+
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
115+
r"^/collections/([^/]+)/bulk_items$": ["POST"],
116+
}
117+
app = app_factory(
118+
upstream_url=source_api_server,
119+
openapi_spec_endpoint=source_api.openapi_url,
120+
default_public=True,
121+
private_endpoints=private_endpoints,
122+
)
123+
client = TestClient(app)
124+
125+
openapi = client.get(source_api.openapi_url).raise_for_status().json()
126+
127+
expected_auth = {
68128
"/collections": ["POST"],
69129
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
70-
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
71130
"/collections/{collection_id}/items": ["POST"],
72131
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
73-
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
74132
"/collections/{collection_id}/bulk_items": ["POST"],
75133
}
134+
for path, method_config in openapi["paths"].items():
135+
for method, config in method_config.items():
136+
security = config.get("security")
137+
path_in_expected_auth = path in expected_auth
138+
method_in_expected_auth = any(
139+
method.casefold() == m.casefold() for m in expected_auth.get(path, [])
140+
)
141+
if security:
142+
assert path_in_expected_auth
143+
assert method_in_expected_auth
144+
else:
145+
assert not all([path_in_expected_auth, method_in_expected_auth])
146+
147+
148+
def test_oidc_in_openapi_spec_public_endpoints(
149+
source_api: FastAPI, source_api_server: str
150+
):
151+
"""When OpenAPI spec endpoint is set & endpoints are marked public, those endpoints are not marked private in the spec."""
152+
public = {r"^/queryables$": ["GET"], r"^/api": ["GET"]}
76153
app = app_factory(
77154
upstream_url=source_api_server,
78155
openapi_spec_endpoint=source_api.openapi_url,
79-
private_endpoints=private_endpoints,
156+
default_public=False,
157+
public_endpoints=public,
80158
)
81159
client = TestClient(app)
160+
82161
openapi = client.get(source_api.openapi_url).raise_for_status().json()
83-
for path, methods in private_endpoints.items():
84-
for method in methods:
85-
openapi_path = openapi["paths"].get(path)
86-
assert openapi_path, f"Path {path} not found in OpenAPI spec"
87-
openapi_path_method = openapi_path.get(method.lower())
88-
assert (
89-
openapi_path_method
90-
), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}"
91-
security = openapi_path_method.get("security")
92-
assert security, f"Security not found for {path!r} {method.lower()!r}"
93-
assert any(
94-
"oidcAuth" in s for s in security
95-
), f'No "oidcAuth" in security for {path!r} {method.lower()!r}'
162+
163+
expected_auth = {"/queryables": ["GET"]}
164+
for path, method_config in openapi["paths"].items():
165+
for method, config in method_config.items():
166+
security = config.get("security")
167+
if security:
168+
assert path not in expected_auth
169+
else:
170+
assert path in expected_auth
171+
assert any(
172+
method.casefold() == m.casefold() for m in expected_auth[path]
173+
)

0 commit comments

Comments
 (0)