Skip to content

Commit 2edc674

Browse files
committed
Merge branch 'main' into authentication-ext/asset-signing
2 parents fdfa62c + bd87d38 commit 2edc674

File tree

11 files changed

+303
-100
lines changed

11 files changed

+303
-100
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ dependencies = [
99
"boto3>=1.37.16",
1010
"brotli>=1.1.0",
1111
"cql2>=0.3.6",
12+
"cryptography>=44.0.1",
1213
"fastapi>=0.115.5",
1314
"httpx[http2]>=0.28.0",
1415
"jinja2>=3.1.4",

src/stac_auth_proxy/app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ async def lifespan(app: FastAPI):
102102
default_public=settings.default_public,
103103
public_endpoints=settings.public_endpoints,
104104
private_endpoints=settings.private_endpoints,
105-
oidc_config_url=settings.oidc_discovery_internal_url,
106105
)
107106

108107
if settings.openapi_spec_endpoint:

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from typing import Any, Optional
88
from urllib.parse import urlparse
99

10-
import httpx
11-
from pydantic import HttpUrl
1210
from starlette.datastructures import Headers
1311
from starlette.requests import Request
1412
from starlette.types import ASGIApp
@@ -33,35 +31,16 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3331
private_endpoints: EndpointMethods
3432
public_endpoints: EndpointMethods
3533

36-
oidc_config_url: Optional[HttpUrl] = None
3734
signing_scheme_name: str = "signed_url_auth"
3835
auth_scheme_name: str = "oauth"
3936
auth_scheme: dict[str, Any] = field(default_factory=dict)
4037
extension_url: str = (
4138
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
4239
)
4340

44-
json_content_type_expr: str = r"application/json|geo\+json)"
41+
json_content_type_expr: str = r"application/(geo\+)?json"
4542

46-
def __post_init__(self):
47-
"""Load after initialization."""
48-
if self.oidc_config_url and not self.auth_scheme:
49-
# Retrieve OIDC configuration and extract authorization and token URLs
50-
oidc_config = httpx.get(str(self.oidc_config_url)).json()
51-
self.auth_scheme = {
52-
"type": "oauth2",
53-
"description": "requires an authentication token",
54-
"flows": {
55-
"authorizationCode": {
56-
"authorizationUrl": oidc_config.get("authorization_endpoint"),
57-
"tokenUrl": oidc_config.get("token_endpoint"),
58-
"scopes": {
59-
k: k
60-
for k in sorted(oidc_config.get("scopes_supported", []))
61-
},
62-
},
63-
},
64-
}
43+
state_key: str = "oidc_metadata"
6544

6645
def should_transform_response(
6746
self, request: Request, response_headers: Headers
@@ -82,23 +61,42 @@ def should_transform_response(
8261
]
8362
)
8463

85-
def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
64+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
8665
"""Augment the STAC Item with auth information."""
87-
extensions = doc.setdefault("stac_extensions", [])
66+
extensions = data.setdefault("stac_extensions", [])
8867
if self.extension_url not in extensions:
8968
extensions.append(self.extension_url)
9069

91-
# TODO: Should we add this to items even if the assets don't match the asset expression?
9270
# auth:schemes
9371
# ---
9472
# A property that contains all of the scheme definitions used by Assets and
9573
# Links in the STAC Item or Collection.
9674
# - Catalogs
9775
# - Collections
9876
# - Item Properties
99-
scheme_loc = doc["properties"] if "properties" in doc else doc
77+
78+
oidc_metadata = getattr(request.state, self.state_key, {})
79+
if not oidc_metadata:
80+
logger.error(
81+
"OIDC metadata not found in scope. Skipping authentication extension."
82+
)
83+
return data
84+
85+
scheme_loc = data["properties"] if "properties" in data else data
10086
schemes = scheme_loc.setdefault("auth:schemes", {})
101-
schemes[self.auth_scheme_name] = self.auth_scheme
87+
schemes[self.auth_scheme_name] = {
88+
"type": "oauth2",
89+
"description": "requires an authentication bearertoken",
90+
"flows": {
91+
"authorizationCode": {
92+
"authorizationUrl": oidc_metadata["authorization_endpoint"],
93+
"tokenUrl": oidc_metadata.get("token_endpoint"),
94+
"scopes": {
95+
k: k for k in sorted(oidc_metadata.get("scopes_supported", []))
96+
},
97+
},
98+
},
99+
}
102100
if self.signing_endpoint:
103101
schemes[self.signing_scheme_name] = {
104102
"type": "signedUrl",
@@ -138,11 +136,11 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
138136
if self.signing_endpoint:
139137
assets = chain(
140138
# Item
141-
doc.get("assets", {}).values(),
139+
data.get("assets", {}).values(),
142140
# Items/Search
143141
(
144142
asset
145-
for item in doc.get("features", [])
143+
for item in data.get("features", [])
146144
for asset in item.get("assets", {}).values()
147145
),
148146
)
@@ -152,16 +150,15 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
152150
continue
153151
if re.match(self.signed_asset_expression, asset["href"]):
154152
asset.setdefault("auth:refs", []).append(self.signing_scheme_name)
155-
156153
# Annotate links with "auth:refs": [auth_scheme]
157154
links = chain(
158155
# Item/Collection
159-
doc.get("links", []),
156+
data.get("links", []),
160157
# Collections/Items/Search
161158
(
162159
link
163160
for prop in ["features", "collections"]
164-
for object_with_links in doc.get(prop, [])
161+
for object_with_links in data.get(prop, [])
165162
for link in object_with_links.get("links", [])
166163
),
167164
)
@@ -179,4 +176,4 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
179176
if match.is_private:
180177
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
181178

182-
return doc
179+
return data

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Middleware to enforce authentication."""
22

33
import logging
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Annotated, Any, Optional, Sequence
66
from urllib.parse import urlparse, urlunparse
77

@@ -18,6 +18,53 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21+
@dataclass
22+
class OidcService:
23+
"""OIDC configuration and JWKS client."""
24+
25+
oidc_config_url: HttpUrl
26+
jwks_client: jwt.PyJWKClient = field(init=False)
27+
metadata: dict[str, Any] = field(init=False)
28+
29+
def __post_init__(self) -> None:
30+
"""Initialize OIDC config and JWKS client."""
31+
logger.debug("Requesting OIDC config")
32+
origin_url = str(self.oidc_config_url)
33+
34+
try:
35+
response = httpx.get(origin_url)
36+
response.raise_for_status()
37+
self.metadata = response.json()
38+
assert self.metadata, "OIDC metadata is empty"
39+
40+
# NOTE: We manually replace the origin of the jwks_uri in the event that
41+
# the jwks_uri is not available from within the proxy.
42+
oidc_url = urlparse(origin_url)
43+
jwks_uri = urlunparse(
44+
urlparse(self.metadata["jwks_uri"])._replace(
45+
netloc=oidc_url.netloc, scheme=oidc_url.scheme
46+
)
47+
)
48+
if jwks_uri != self.metadata["jwks_uri"]:
49+
logger.warning(
50+
"JWKS URI has been rewritten from %s to %s",
51+
self.metadata["jwks_uri"],
52+
jwks_uri,
53+
)
54+
self.jwks_client = jwt.PyJWKClient(jwks_uri)
55+
except httpx.HTTPStatusError as e:
56+
logger.error(
57+
"Received a non-200 response when fetching OIDC config: %s",
58+
e.response.text,
59+
)
60+
raise OidcFetchError(
61+
f"Request for OIDC config failed with status {e.response.status_code}"
62+
) from e
63+
except httpx.RequestError as e:
64+
logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
65+
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
66+
67+
2168
@dataclass
2269
class EnforceAuthMiddleware:
2370
"""Middleware to enforce authentication."""
@@ -26,56 +73,11 @@ class EnforceAuthMiddleware:
2673
private_endpoints: EndpointMethods
2774
public_endpoints: EndpointMethods
2875
default_public: bool
29-
3076
oidc_config_url: HttpUrl
3177
allowed_jwt_audiences: Optional[Sequence[str]] = None
32-
3378
state_key: str = "payload"
3479

35-
# Generated attributes
36-
_jwks_client: Optional[jwt.PyJWKClient] = None
37-
38-
@property
39-
def jwks_client(self) -> jwt.PyJWKClient:
40-
"""Get the OIDC configuration URL."""
41-
if not self._jwks_client:
42-
logger.debug("Requesting OIDC config")
43-
origin_url = str(self.oidc_config_url)
44-
45-
try:
46-
response = httpx.get(origin_url)
47-
response.raise_for_status()
48-
oidc_config = response.json()
49-
50-
# NOTE: We manually replace the origin of the jwks_uri in the event that
51-
# the jwks_uri is not available from within the proxy.
52-
oidc_url = urlparse(origin_url)
53-
jwks_uri = urlunparse(
54-
urlparse(oidc_config["jwks_uri"])._replace(
55-
netloc=oidc_url.netloc, scheme=oidc_url.scheme
56-
)
57-
)
58-
if jwks_uri != oidc_config["jwks_uri"]:
59-
logger.warning(
60-
"JWKS URI has been rewritten from %s to %s",
61-
oidc_config["jwks_uri"],
62-
jwks_uri,
63-
)
64-
self._jwks_client = jwt.PyJWKClient(jwks_uri)
65-
except httpx.HTTPStatusError as e:
66-
logger.error(
67-
"Received a non-200 response when fetching OIDC config: %s",
68-
e.response.text,
69-
)
70-
raise OidcFetchError(
71-
f"Request for OIDC config failed with status {e.response.status_code}"
72-
) from e
73-
except httpx.RequestError as e:
74-
logger.error(
75-
"Error fetching OIDC config from %s: %s", origin_url, str(e)
76-
)
77-
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
78-
return self._jwks_client
80+
_oidc_config: Optional[OidcService] = None
7981

8082
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8183
"""Enforce authentication."""
@@ -107,6 +109,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
107109
self.state_key,
108110
payload,
109111
)
112+
setattr(request.state, "oidc_metadata", self.oidc_config.metadata)
110113
return await self.app(scope, receive, send)
111114

112115
def validate_token(
@@ -137,7 +140,7 @@ def validate_token(
137140

138141
# Parse & validate token
139142
try:
140-
key = self.jwks_client.get_signing_key_from_jwt(token).key
143+
key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key
141144
payload = jwt.decode(
142145
token,
143146
key,
@@ -163,6 +166,13 @@ def validate_token(
163166
)
164167
return payload
165168

169+
@property
170+
def oidc_config(self) -> OidcService:
171+
"""Get the OIDC configuration."""
172+
if not self._oidc_config:
173+
self._oidc_config = OidcService(oidc_config_url=self.oidc_config_url)
174+
return self._oidc_config
175+
166176

167177
class OidcFetchError(Exception):
168178
"""Error fetching OIDC configuration."""

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ def should_transform_response(
4141
]
4242
)
4343

44-
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
44+
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
4545
"""Augment the OpenAPI spec with auth information."""
46-
components = openapi_spec.setdefault("components", {})
46+
components = data.setdefault("components", {})
4747
securitySchemes = components.setdefault("securitySchemes", {})
4848
securitySchemes[self.oidc_auth_scheme_name] = {
4949
"type": "openIdConnect",
5050
"openIdConnectUrl": self.oidc_config_url,
5151
}
52-
for path, method_config in openapi_spec["paths"].items():
52+
for path, method_config in data["paths"].items():
5353
for method, config in method_config.items():
5454
match = find_match(
5555
path,
@@ -62,4 +62,4 @@ def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
6262
config.setdefault("security", []).append(
6363
{self.oidc_auth_scheme_name: match.required_scopes}
6464
)
65-
return openapi_spec
65+
return data

src/stac_auth_proxy/utils/middleware.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def should_transform_response(
2929
...
3030

3131
@abstractmethod
32-
def transform_json(self, data: Any) -> Any:
32+
def transform_json(self, data: Any, request: Request) -> Any:
3333
"""
3434
Transform the JSON data.
3535
@@ -56,9 +56,10 @@ async def transform_response(message: Message) -> None:
5656

5757
start_message = start_message or message
5858
headers = MutableHeaders(scope=start_message)
59+
request = Request(scope)
5960

6061
if not self.should_transform_response(
61-
request=Request(scope),
62+
request=request,
6263
response_headers=headers,
6364
):
6465
# For non-JSON responses, send the start message immediately
@@ -78,7 +79,7 @@ async def transform_response(message: Message) -> None:
7879
# Transform the JSON body
7980
if body:
8081
data = json.loads(body)
81-
transformed = self.transform_json(data)
82+
transformed = self.transform_json(data, request=request)
8283
body = json.dumps(transformed).encode()
8384

8485
# Update content-length header

tests/conftest.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@ def public_key(test_key: jwk.JWK) -> dict[str, Any]:
3232
@pytest.fixture(autouse=True)
3333
def mock_jwks(public_key: dict[str, Any]):
3434
"""Mock JWKS endpoint."""
35-
mock_oidc_config = {"jwks_uri": "https://example.com/jwks"}
35+
mock_oidc_config = {
36+
"jwks_uri": "https://example.com/jwks",
37+
"authorization_endpoint": "https://example.com/auth",
38+
"token_endpoint": "https://example.com/token",
39+
"scopes_supported": ["openid", "profile", "email", "collection:create"],
40+
}
3641

3742
mock_jwks = {"keys": [public_key]}
3843

0 commit comments

Comments
 (0)