Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions src/stac_auth_proxy/handlers/open_api_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi import Request, Response
from fastapi.routing import APIRoute

from ..utils import safe_headers
from ..utils import has_any_security_requirements, safe_headers
from .reverse_proxy import ReverseProxyHandler

logger = logging.getLogger(__name__)
Expand All @@ -28,26 +28,29 @@ async def dispatch(self, req: Request, res: Response):
# Pass along the response headers
res.headers.update(safe_headers(oidc_spec_response.headers))

# Add the OIDC security scheme to the components
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
self.auth_scheme_name
] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}

proxy_auth_routes = [
r
for r in req.app.routes
# Ignore non-APIRoutes (we can't check their security dependencies)
if isinstance(r, APIRoute)
# Ignore routes that don't have security requirements
and (
r.dependant.security_requirements
or any(d.security_requirements for d in r.dependant.dependencies)
)
and has_any_security_requirements(r.dependant)
]

if not proxy_auth_routes:
logger.warning(
"No routes with security requirements found. OIDC security requirements will not be added."
)
return openapi_spec

# Add the OIDC security scheme to the components
openapi_spec.setdefault("components", {}).setdefault("securitySchemes", {})[
self.auth_scheme_name
] = {
"type": "openIdConnect",
"openIdConnectUrl": self.oidc_config_url,
}

# Update the paths with the specified security requirements
for path, method_config in openapi_spec["paths"].items():
for method, config in method_config.items():
Expand All @@ -59,7 +62,7 @@ async def dispatch(self, req: Request, res: Response):
continue
# Add the OIDC security requirement
config.setdefault("security", []).append(
[{self.auth_scheme_name: []}]
{self.auth_scheme_name: []}
)
break

Expand Down
13 changes: 13 additions & 0 deletions src/stac_auth_proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import re
from urllib.parse import urlparse

from fastapi.dependencies.models import Dependant
from httpx import Headers


Expand All @@ -29,3 +30,15 @@ def extract_variables(url: str) -> dict:
pattern = r"^/collections/(?P<collection_id>[^/]+)(?:/(?:items|bulk_items)(?:/(?P<item_id>[^/]+))?)?/?$"
match = re.match(pattern, path)
return {k: v for k, v in match.groupdict().items() if v} if match else {}


def has_any_security_requirements(dependency: Dependant) -> bool:
"""
Recursively check if any dependency within the hierarchy has a non-empty
security_requirements list.
"""
if dependency.security_requirements:
return True
return any(
has_any_security_requirements(sub_dep) for sub_dep in dependency.dependencies
)
47 changes: 39 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Pytest fixtures."""

import json
import os
import threading
from typing import Any
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -68,19 +69,42 @@ def source_api():
app = FastAPI(docs_url="/api.html", openapi_url="/api")

for path, methods in {
"/": ["GET"],
"/conformance": ["GET"],
"/queryables": ["GET"],
"/search": ["GET", "POST"],
"/collections": ["GET", "POST"],
"/collections/{collection_id}": ["GET", "PUT", "DELETE"],
"/collections/{collection_id}/items": ["GET", "POST"],
"/": [
"GET",
],
"/conformance": [
"GET",
],
"/queryables": [
"GET",
],
"/search": [
"GET",
"POST",
],
"/collections": [
"GET",
"POST",
],
"/collections/{collection_id}": [
"GET",
"PUT",
"PATCH",
"DELETE",
],
"/collections/{collection_id}/items": [
"GET",
"POST",
],
"/collections/{collection_id}/items/{item_id}": [
"GET",
"PUT",
"PATCH",
"DELETE",
],
"/collections/{collection_id}/bulk_items": ["POST"],
"/collections/{collection_id}/bulk_items": [
"POST",
],
}.items():
for method in methods:
# NOTE: declare routes per method separately to avoid warning of "Duplicate Operation ID ... for function <lambda>"
Expand Down Expand Up @@ -109,3 +133,10 @@ def source_api_server(source_api):
yield f"http://{host}:{port}"
server.should_exit = True
thread.join()


@pytest.fixture(autouse=True, scope="module")
def mock_env():
"""Clear environment variables to avoid poluting configs from runtime env."""
with patch.dict(os.environ, clear=True):
yield
56 changes: 55 additions & 1 deletion tests/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
)


def test_no_edit_openapi_spec(source_api_server):
def test_no_openapi_spec_endpoint(source_api_server):
"""When no OpenAPI spec endpoint is set, the proxied OpenAPI spec is unaltered."""
app = app_factory(
upstream_url=source_api_server,
Expand All @@ -25,6 +25,24 @@ def test_no_edit_openapi_spec(source_api_server):
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})


def test_no_private_endpoints(source_api_server):
"""When no endpoints are private, the proxied OpenAPI spec is unaltered."""
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint="/api",
private_endpoints={},
default_public=True,
)
client = TestClient(app)
response = client.get("/api")
assert response.status_code == 200
openapi = response.json()
assert "info" in openapi
assert "openapi" in openapi
assert "paths" in openapi
assert "oidcAuth" not in openapi.get("components", {}).get("securitySchemes", {})


def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
"""When OpenAPI spec endpoint is set, the proxied OpenAPI spec is augmented with oidc details."""
app = app_factory(
Expand All @@ -39,3 +57,39 @@ def test_oidc_in_openapi_spec(source_api: FastAPI, source_api_server: str):
assert "openapi" in openapi
assert "paths" in openapi
assert "oidcAuth" in openapi.get("components", {}).get("securitySchemes", {})


def test_oidc_in_openapi_spec_private_endpoints(
source_api: FastAPI, source_api_server: str
):
"""When OpenAPI spec endpoint is set & endpoints are marked private, those endpoints are marked private in the spec."""
private_endpoints = {
# https://github.com/stac-api-extensions/collection-transaction/blob/v1.0.0-beta.1/README.md#methods
"/collections": ["POST"],
"/collections/{collection_id}": ["PUT", "PATCH", "DELETE"],
# https://github.com/stac-api-extensions/transaction/blob/v1.0.0-rc.3/README.md#methods
"/collections/{collection_id}/items": ["POST"],
"/collections/{collection_id}/items/{item_id}": ["PUT", "PATCH", "DELETE"],
# https://stac-utils.github.io/stac-fastapi/api/stac_fastapi/extensions/third_party/bulk_transactions/#bulktransactionextension
"/collections/{collection_id}/bulk_items": ["POST"],
}
app = app_factory(
upstream_url=source_api_server,
openapi_spec_endpoint=source_api.openapi_url,
private_endpoints=private_endpoints,
)
client = TestClient(app)
openapi = client.get(source_api.openapi_url).raise_for_status().json()
for path, methods in private_endpoints.items():
for method in methods:
openapi_path = openapi["paths"].get(path)
assert openapi_path, f"Path {path} not found in OpenAPI spec"
openapi_path_method = openapi_path.get(method.lower())
assert (
openapi_path_method
), f"Method {method.lower()!r} not found for path {path!r} in OpenAPI spec for path {path}"
security = openapi_path_method.get("security")
assert security, f"Security not found for {path!r} {method.lower()!r}"
assert any(
"oidcAuth" in s for s in security
), f'No "oidcAuth" in security for {path!r} {method.lower()!r}'
Loading