Skip to content

Commit 5ff51ca

Browse files
committed
fix: handle deeply nested security dependencies (#14)
* Add failing test * bugfix: handle deeply nested security dependencies * bugfix: prevent env pollution * expand tests for complete coverage
1 parent 44116b7 commit 5ff51ca

File tree

4 files changed

+124
-23
lines changed

4 files changed

+124
-23
lines changed

src/stac_auth_proxy/handlers/open_api_spec.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fastapi import Request, Response
77
from fastapi.routing import APIRoute
88

9-
from ..utils import safe_headers
9+
from ..utils import has_any_security_requirements, safe_headers
1010
from .reverse_proxy import ReverseProxyHandler
1111

1212
logger = logging.getLogger(__name__)
@@ -28,26 +28,29 @@ async def dispatch(self, req: Request, res: Response):
2828
# Pass along the response headers
2929
res.headers.update(safe_headers(oidc_spec_response.headers))
3030

31-
# Add the OIDC security scheme to the components
32-
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
33-
self.auth_scheme_name
34-
] = {
35-
"type": "openIdConnect",
36-
"openIdConnectUrl": self.oidc_config_url,
37-
}
38-
3931
proxy_auth_routes = [
4032
r
4133
for r in req.app.routes
4234
# Ignore non-APIRoutes (we can't check their security dependencies)
4335
if isinstance(r, APIRoute)
4436
# Ignore routes that don't have security requirements
45-
and (
46-
r.dependant.security_requirements
47-
or any(d.security_requirements for d in r.dependant.dependencies)
48-
)
37+
and has_any_security_requirements(r.dependant)
4938
]
5039

40+
if not proxy_auth_routes:
41+
logger.warning(
42+
"No routes with security requirements found. OIDC security requirements will not be added."
43+
)
44+
return openapi_spec
45+
46+
# Add the OIDC security scheme to the components
47+
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
48+
self.auth_scheme_name
49+
] = {
50+
"type": "openIdConnect",
51+
"openIdConnectUrl": self.oidc_config_url,
52+
}
53+
5154
# Update the paths with the specified security requirements
5255
for path, method_config in openapi_spec["paths"].items():
5356
for method, config in method_config.items():
@@ -59,7 +62,7 @@ async def dispatch(self, req: Request, res: Response):
5962
continue
6063
# Add the OIDC security requirement
6164
config.setdefault("security", []).append(
62-
[{self.auth_scheme_name: []}]
65+
{self.auth_scheme_name: []}
6366
)
6467
break
6568

src/stac_auth_proxy/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import re
44
from urllib.parse import urlparse
55

6+
from fastapi.dependencies.models import Dependant
67
from httpx import Headers
78

89

@@ -29,3 +30,15 @@ def extract_variables(url: str) -> dict:
2930
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
3031
match = re.match(pattern, path)
3132
return {k: v for k, v in match.groupdict().items() if v} if match else {}
33+
34+
35+
def has_any_security_requirements(dependency: Dependant) -> bool:
36+
"""
37+
Recursively check if any dependency within the hierarchy has a non-empty
38+
security_requirements list.
39+
"""
40+
if dependency.security_requirements:
41+
return True
42+
return any(
43+
has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies
44+
)

tests/conftest.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Pytest fixtures."""
22

33
import json
4+
import os
45
import threading
56
from typing import Any
67
from unittest.mock import MagicMock, patch
@@ -68,19 +69,42 @@ def source_api():
6869
app = FastAPI(docs_url="/api.html", openapi_url="/api")
6970

7071
for path, methods in {
71-
"/": ["GET"],
72-
"/conformance": ["GET"],
73-
"/queryables": ["GET"],
74-
"/search": ["GET", "POST"],
75-
"/collections": ["GET", "POST"],
76-
"/collections/{collection_id}": ["GET", "PUT", "DELETE"],
77-
"/collections/{collection_id}/items": ["GET", "POST"],
72+
"/": [
73+
"GET",
74+
],
75+
"/conformance": [
76+
"GET",
77+
],
78+
"/queryables": [
79+
"GET",
80+
],
81+
"/search": [
82+
"GET",
83+
"POST",
84+
],
85+
"/collections": [
86+
"GET",
87+
"POST",
88+
],
89+
"/collections/{collection_id}": [
90+
"GET",
91+
"PUT",
92+
"PATCH",
93+
"DELETE",
94+
],
95+
"/collections/{collection_id}/items": [
96+
"GET",
97+
"POST",
98+
],
7899
"/collections/{collection_id}/items/{item_id}": [
79100
"GET",
80101
"PUT",
102+
"PATCH",
81103
"DELETE",
82104
],
83-
"/collections/{collection_id}/bulk_items": ["POST"],
105+
"/collections/{collection_id}/bulk_items": [
106+
"POST",
107+
],
84108
}.items():
85109
for method in methods:
86110
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
@@ -109,3 +133,10 @@ def source_api_server(source_api):
109133
yield f"http://{host}:{port}"
110134
server.should_exit = True
111135
thread.join()
136+
137+
138+
@pytest.fixture(autouse=True, scope="module")
139+
def mock_env():
140+
"""Clear environment variables to avoid poluting configs from runtime env."""
141+
with patch.dict(os.environ, clear=True):
142+
yield

tests/test_openapi.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
)
1010

1111

12-
def test_no_edit_openapi_spec(source_api_server):
12+
def test_no_openapi_spec_endpoint(source_api_server):
1313
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
1414
app = app_factory(
1515
upstream_url=source_api_server,
@@ -25,6 +25,24 @@ def test_no_edit_openapi_spec(source_api_server):
2525
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})
2626

2727

28+
def test_no_private_endpoints(source_api_server):
29+
"""When no endpoints are private, the proxied OpenAPI spec is unaltered."""
30+
app = app_factory(
31+
upstream_url=source_api_server,
32+
openapi_spec_endpoint="/api",
33+
private_endpoints={},
34+
default_public=True,
35+
)
36+
client = TestClient(app)
37+
response = client.get("/api")
38+
assert response.status_code == 200
39+
openapi = response.json()
40+
assert "info" in openapi
41+
assert "openapi" in openapi
42+
assert "paths" in openapi
43+
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})
44+
45+
2846
def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
2947
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
3048
app = app_factory(
@@ -39,3 +57,39 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
3957
assert "openapi" in openapi
4058
assert "paths" in openapi
4159
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})
60+
61+
62+
def test_oidc_in_openapi_spec_private_endpoints(
63+
source_api: FastAPI, source_api_server: str
64+
):
65+
"""When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec."""
66+
private_endpoints = {
67+
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
68+
"/collections": ["POST"],
69+
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
70+
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
71+
"/collections/{collection_id}/items": ["POST"],
72+
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
73+
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
74+
"/collections/{collection_id}/bulk_items": ["POST"],
75+
}
76+
app = app_factory(
77+
upstream_url=source_api_server,
78+
openapi_spec_endpoint=source_api.openapi_url,
79+
private_endpoints=private_endpoints,
80+
)
81+
client = TestClient(app)
82+
openapi = client.get(source_api.openapi_url).raise_for_status().json()
83+
for path, methods in private_endpoints.items():
84+
for method in methods:
85+
openapi_path = openapi["paths"].get(path)
86+
assert openapi_path, f"Path {path} not found in OpenAPI spec"
87+
openapi_path_method = openapi_path.get(method.lower())
88+
assert (
89+
openapi_path_method
90+
), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}"
91+
security = openapi_path_method.get("security")
92+
assert security, f"Security not found for {path!r} {method.lower()!r}"
93+
assert any(
94+
"oidcAuth" in s for s in security
95+
), f'No "oidcAuth" in security for {path!r} {method.lower()!r}'

0 commit comments

Comments
 (0)