Skip to content

Commit bd87d38

Browse files
authored
feat: integrate with Authentication Extension (#41)
Augments relevant documents with: * Adds `"auth:refs": ["oauth"]` onto link for any protected resource * Adds `"auth:schemes"` with description of OIDC server to bottom of document * Adds `https://stac-extensions.github.io/authentication/v1.1.0/schema.json` to `extensions` array --- closes #35
1 parent 9c51ce0 commit bd87d38

12 files changed

+498
-73
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ classifiers = [
66
"License :: OSI Approved :: MIT License",
77
]
88
dependencies = [
9+
"boto3>=1.37.16",
910
"brotli>=1.1.0",
1011
"cql2>=0.3.6",
12+
"cryptography>=44.0.1",
1113
"fastapi>=0.115.5",
1214
"httpx[http2]>=0.28.0",
1315
"jinja2>=3.1.4",

src/stac_auth_proxy/app.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from .middleware import (
1818
AddProcessTimeHeaderMiddleware,
1919
ApplyCql2FilterMiddleware,
20+
AuthenticationExtensionMiddleware,
2021
BuildCql2FilterMiddleware,
2122
EnforceAuthMiddleware,
2223
OpenApiMiddleware,
@@ -86,6 +87,13 @@ async def lifespan(app: FastAPI):
8687
#
8788
# Middleware (order is important, last added = first to run)
8889
#
90+
app.add_middleware(
91+
AuthenticationExtensionMiddleware,
92+
default_public=settings.default_public,
93+
public_endpoints=settings.public_endpoints,
94+
private_endpoints=settings.private_endpoints,
95+
)
96+
8997
if settings.openapi_spec_endpoint:
9098
app.add_middleware(
9199
OpenApiMiddleware,
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
"""Middleware to add auth information to item response served by upstream API."""
2+
3+
import logging
4+
import re
5+
from dataclasses import dataclass, field
6+
from itertools import chain
7+
from typing import Any
8+
from urllib.parse import urlparse
9+
10+
from starlette.datastructures import Headers
11+
from starlette.requests import Request
12+
from starlette.types import ASGIApp
13+
14+
from ..config import EndpointMethods
15+
from ..utils.middleware import JsonResponseMiddleware
16+
from ..utils.requests import find_match
17+
18+
logger = logging.getLogger(__name__)
19+
20+
21+
@dataclass
22+
class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
23+
"""Middleware to add the authentication extension to the response."""
24+
25+
app: ASGIApp
26+
27+
default_public: bool
28+
private_endpoints: EndpointMethods
29+
public_endpoints: EndpointMethods
30+
31+
auth_scheme_name: str = "oauth"
32+
auth_scheme: dict[str, Any] = field(default_factory=dict)
33+
extension_url: str = (
34+
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
35+
)
36+
37+
json_content_type_expr: str = r"application/(geo\+)?json"
38+
39+
state_key: str = "oidc_metadata"
40+
41+
def should_transform_response(
42+
self, request: Request, response_headers: Headers
43+
) -> bool:
44+
"""Determine if the response should be transformed."""
45+
# Match STAC catalog, collection, or item URLs with a single regex
46+
return all(
47+
[
48+
re.match(
49+
# catalog, collections, collection, items, item, search
50+
r"^(/|/collections(/[^/]+(/items(/[^/]+)?)?)?|/search)$",
51+
request.url.path,
52+
),
53+
re.match(
54+
self.json_content_type_expr,
55+
response_headers.get("content-type", ""),
56+
),
57+
]
58+
)
59+
60+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
61+
"""Augment the STAC Item with auth information."""
62+
extensions = data.setdefault("stac_extensions", [])
63+
if self.extension_url not in extensions:
64+
extensions.append(self.extension_url)
65+
66+
# auth:schemes
67+
# ---
68+
# A property that contains all of the scheme definitions used by Assets and
69+
# Links in the STAC Item or Collection.
70+
# - Catalogs
71+
# - Collections
72+
# - Item Properties
73+
74+
oidc_metadata = getattr(request.state, self.state_key, {})
75+
if not oidc_metadata:
76+
logger.error(
77+
"OIDC metadata not found in scope. Skipping authentication extension."
78+
)
79+
return data
80+
81+
scheme_loc = data["properties"] if "properties" in data else data
82+
schemes = scheme_loc.setdefault("auth:schemes", {})
83+
schemes[self.auth_scheme_name] = {
84+
"type": "oauth2",
85+
"description": "requires an authentication bearertoken",
86+
"flows": {
87+
"authorizationCode": {
88+
"authorizationUrl": oidc_metadata["authorization_endpoint"],
89+
"tokenUrl": oidc_metadata.get("token_endpoint"),
90+
"scopes": {
91+
k: k for k in sorted(oidc_metadata.get("scopes_supported", []))
92+
},
93+
},
94+
},
95+
}
96+
97+
# auth:refs
98+
# ---
99+
# Annotate links with "auth:refs": [auth_scheme]
100+
links = chain(
101+
# Item/Collection
102+
data.get("links", []),
103+
# Collections/Items/Search
104+
(
105+
link
106+
for prop in ["features", "collections"]
107+
for object_with_links in data.get(prop, [])
108+
for link in object_with_links.get("links", [])
109+
),
110+
)
111+
for link in links:
112+
if "href" not in link:
113+
logger.warning("Link %s has no href", link)
114+
continue
115+
match = find_match(
116+
path=urlparse(link["href"]).path,
117+
method="GET",
118+
private_endpoints=self.private_endpoints,
119+
public_endpoints=self.public_endpoints,
120+
default_public=self.default_public,
121+
)
122+
if match.is_private:
123+
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
124+
125+
return data

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Middleware to enforce authentication."""
22

33
import logging
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Annotated, Any, Optional, Sequence
66
from urllib.parse import urlparse, urlunparse
77

@@ -18,6 +18,53 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21+
@dataclass
22+
class OidcService:
23+
"""OIDC configuration and JWKS client."""
24+
25+
oidc_config_url: HttpUrl
26+
jwks_client: jwt.PyJWKClient = field(init=False)
27+
metadata: dict[str, Any] = field(init=False)
28+
29+
def __post_init__(self) -> None:
30+
"""Initialize OIDC config and JWKS client."""
31+
logger.debug("Requesting OIDC config")
32+
origin_url = str(self.oidc_config_url)
33+
34+
try:
35+
response = httpx.get(origin_url)
36+
response.raise_for_status()
37+
self.metadata = response.json()
38+
assert self.metadata, "OIDC metadata is empty"
39+
40+
# NOTE: We manually replace the origin of the jwks_uri in the event that
41+
# the jwks_uri is not available from within the proxy.
42+
oidc_url = urlparse(origin_url)
43+
jwks_uri = urlunparse(
44+
urlparse(self.metadata["jwks_uri"])._replace(
45+
netloc=oidc_url.netloc, scheme=oidc_url.scheme
46+
)
47+
)
48+
if jwks_uri != self.metadata["jwks_uri"]:
49+
logger.warning(
50+
"JWKS URI has been rewritten from %s to %s",
51+
self.metadata["jwks_uri"],
52+
jwks_uri,
53+
)
54+
self.jwks_client = jwt.PyJWKClient(jwks_uri)
55+
except httpx.HTTPStatusError as e:
56+
logger.error(
57+
"Received a non-200 response when fetching OIDC config: %s",
58+
e.response.text,
59+
)
60+
raise OidcFetchError(
61+
f"Request for OIDC config failed with status {e.response.status_code}"
62+
) from e
63+
except httpx.RequestError as e:
64+
logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
65+
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
66+
67+
2168
@dataclass
2269
class EnforceAuthMiddleware:
2370
"""Middleware to enforce authentication."""
@@ -26,56 +73,11 @@ class EnforceAuthMiddleware:
2673
private_endpoints: EndpointMethods
2774
public_endpoints: EndpointMethods
2875
default_public: bool
29-
3076
oidc_config_url: HttpUrl
3177
allowed_jwt_audiences: Optional[Sequence[str]] = None
32-
3378
state_key: str = "payload"
3479

35-
# Generated attributes
36-
_jwks_client: Optional[jwt.PyJWKClient] = None
37-
38-
@property
39-
def jwks_client(self) -> jwt.PyJWKClient:
40-
"""Get the OIDC configuration URL."""
41-
if not self._jwks_client:
42-
logger.debug("Requesting OIDC config")
43-
origin_url = str(self.oidc_config_url)
44-
45-
try:
46-
response = httpx.get(origin_url)
47-
response.raise_for_status()
48-
oidc_config = response.json()
49-
50-
# NOTE: We manually replace the origin of the jwks_uri in the event that
51-
# the jwks_uri is not available from within the proxy.
52-
oidc_url = urlparse(origin_url)
53-
jwks_uri = urlunparse(
54-
urlparse(oidc_config["jwks_uri"])._replace(
55-
netloc=oidc_url.netloc, scheme=oidc_url.scheme
56-
)
57-
)
58-
if jwks_uri != oidc_config["jwks_uri"]:
59-
logger.warning(
60-
"JWKS URI has been rewritten from %s to %s",
61-
oidc_config["jwks_uri"],
62-
jwks_uri,
63-
)
64-
self._jwks_client = jwt.PyJWKClient(jwks_uri)
65-
except httpx.HTTPStatusError as e:
66-
logger.error(
67-
"Received a non-200 response when fetching OIDC config: %s",
68-
e.response.text,
69-
)
70-
raise OidcFetchError(
71-
f"Request for OIDC config failed with status {e.response.status_code}"
72-
) from e
73-
except httpx.RequestError as e:
74-
logger.error(
75-
"Error fetching OIDC config from %s: %s", origin_url, str(e)
76-
)
77-
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
78-
return self._jwks_client
80+
_oidc_config: Optional[OidcService] = None
7981

8082
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8183
"""Enforce authentication."""
@@ -107,6 +109,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
107109
self.state_key,
108110
payload,
109111
)
112+
setattr(request.state, "oidc_metadata", self.oidc_config.metadata)
110113
return await self.app(scope, receive, send)
111114

112115
def validate_token(
@@ -137,7 +140,7 @@ def validate_token(
137140

138141
# Parse & validate token
139142
try:
140-
key = self.jwks_client.get_signing_key_from_jwt(token).key
143+
key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key
141144
payload = jwt.decode(
142145
token,
143146
key,
@@ -163,6 +166,13 @@ def validate_token(
163166
)
164167
return payload
165168

169+
@property
170+
def oidc_config(self) -> OidcService:
171+
"""Get the OIDC configuration."""
172+
if not self._oidc_config:
173+
self._oidc_config = OidcService(oidc_config_url=self.oidc_config_url)
174+
return self._oidc_config
175+
166176

167177
class OidcFetchError(Exception):
168178
"""Error fetching OIDC configuration."""

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def should_transform_response(
4141
]
4242
)
4343

44-
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
44+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
4545
"""Augment the OpenAPI spec with auth information."""
46-
components = openapi_spec.setdefault("components", {})
46+
components = data.setdefault("components", {})
4747
securitySchemes = components.setdefault("securitySchemes", {})
4848
securitySchemes[self.oidc_auth_scheme_name] = {
4949
"type": "openIdConnect",
5050
"openIdConnectUrl": self.oidc_config_url,
5151
}
52-
for path, method_config in openapi_spec["paths"].items():
52+
for path, method_config in data["paths"].items():
5353
for method, config in method_config.items():
5454
match = find_match(
5555
path,
@@ -62,4 +62,4 @@ def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
6262
config.setdefault("security", []).append(
6363
{self.oidc_auth_scheme_name: match.required_scopes}
6464
)
65-
return openapi_spec
65+
return data

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,16 @@
22

33
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
44
from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware
5+
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
56
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
67
from .EnforceAuthMiddleware import EnforceAuthMiddleware
78
from .UpdateOpenApiMiddleware import OpenApiMiddleware
89

910
__all__ = [
10-
x.__name__
11-
for x in [
12-
OpenApiMiddleware,
13-
AddProcessTimeHeaderMiddleware,
14-
EnforceAuthMiddleware,
15-
BuildCql2FilterMiddleware,
16-
ApplyCql2FilterMiddleware,
17-
]
11+
"AddProcessTimeHeaderMiddleware",
12+
"ApplyCql2FilterMiddleware",
13+
"AuthenticationExtensionMiddleware",
14+
"BuildCql2FilterMiddleware",
15+
"EnforceAuthMiddleware",
16+
"OpenApiMiddleware",
1817
]

src/stac_auth_proxy/utils/middleware.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def should_transform_response(
2929
...
3030

3131
@abstractmethod
32-
def transform_json(self, data: Any) -> Any:
32+
def transform_json(self, data: Any, request: Request) -> Any:
3333
"""
3434
Transform the JSON data.
3535
@@ -56,9 +56,10 @@ async def transform_response(message: Message) -> None:
5656

5757
start_message = start_message or message
5858
headers = MutableHeaders(scope=start_message)
59+
request = Request(scope)
5960

6061
if not self.should_transform_response(
61-
request=Request(scope),
62+
request=request,
6263
response_headers=headers,
6364
):
6465
# For non-JSON responses, send the start message immediately
@@ -78,7 +79,7 @@ async def transform_response(message: Message) -> None:
7879
# Transform the JSON body
7980
if body:
8081
data = json.loads(body)
81-
transformed = self.transform_json(data)
82+
transformed = self.transform_json(data, request=request)
8283
body = json.dumps(transformed).encode()
8384

8485
# Update content-length header

0 commit comments

Comments
 (0)