Skip to content

Commit 27b3552

Browse files
committed
Retrieve oidc info from scope
1 parent 3f0b6af commit 27b3552

File tree

6 files changed

+97
-83
lines changed

6 files changed

+97
-83
lines changed

src/stac_auth_proxy/app.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,6 @@ async def lifespan(app: FastAPI):
9292
default_public=settings.default_public,
9393
public_endpoints=settings.public_endpoints,
9494
private_endpoints=settings.private_endpoints,
95-
oidc_config_url=settings.oidc_discovery_internal_url,
9695
)
9796

9897
if settings.openapi_spec_endpoint:

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,12 @@
44
import re
55
from dataclasses import dataclass, field
66
from itertools import chain
7-
from typing import Any, Optional
7+
from typing import Any
88
from urllib.parse import urlparse
99

10-
import httpx
11-
from pydantic import HttpUrl
1210
from starlette.datastructures import Headers
1311
from starlette.requests import Request
14-
from starlette.types import ASGIApp
12+
from starlette.types import ASGIApp, Scope
1513

1614
from ..config import EndpointMethods
1715
from ..utils.middleware import JsonResponseMiddleware
@@ -30,34 +28,13 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3028
private_endpoints: EndpointMethods
3129
public_endpoints: EndpointMethods
3230

33-
oidc_config_url: Optional[HttpUrl] = None
3431
auth_scheme_name: str = "oauth"
3532
auth_scheme: dict[str, Any] = field(default_factory=dict)
3633
extension_url: str = (
3734
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
3835
)
3936

40-
json_content_type_expr: str = r"application/json|geo\+json)"
41-
42-
def __post_init__(self):
43-
"""Load after initialization."""
44-
if self.oidc_config_url and not self.auth_scheme:
45-
# Retrieve OIDC configuration and extract authorization and token URLs
46-
oidc_config = httpx.get(str(self.oidc_config_url)).json()
47-
self.auth_scheme = {
48-
"type": "oauth2",
49-
"description": "requires an authentication token",
50-
"flows": {
51-
"authorizationCode": {
52-
"authorizationUrl": oidc_config.get("authorization_endpoint"),
53-
"tokenUrl": oidc_config.get("token_endpoint"),
54-
"scopes": {
55-
k: k
56-
for k in sorted(oidc_config.get("scopes_supported", []))
57-
},
58-
},
59-
},
60-
}
37+
json_content_type_expr: str = r"(application/json|geo\+json)"
6138

6239
def should_transform_response(
6340
self, request: Request, response_headers: Headers
@@ -78,7 +55,7 @@ def should_transform_response(
7855
]
7956
)
8057

81-
def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
58+
def transform_json(self, doc: dict[str, Any], scope: Scope) -> dict[str, Any]:
8259
"""Augment the STAC Item with auth information."""
8360
extensions = doc.setdefault("stac_extensions", [])
8461
if self.extension_url not in extensions:
@@ -92,9 +69,19 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
9269
# - Catalogs
9370
# - Collections
9471
# - Item Properties
72+
73+
if "oidc_metadata" not in scope:
74+
logger.error(
75+
"OIDC metadata not found in scope. "
76+
"Skipping authentication extension."
77+
)
78+
return doc
79+
9580
scheme_loc = doc["properties"] if "properties" in doc else doc
9681
schemes = scheme_loc.setdefault("auth:schemes", {})
97-
schemes[self.auth_scheme_name] = self.auth_scheme
82+
schemes[self.auth_scheme_name] = self.parse_oidc_config(
83+
scope.get("oidc_metadata", {})
84+
)
9885

9986
# auth:refs
10087
# ---
@@ -125,3 +112,19 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
125112
link.setdefault("auth:refs", []).append(self.auth_scheme_name)
126113

127114
return doc
115+
116+
def parse_oidc_config(self, oidc_config: dict[str, Any]) -> dict[str, Any]:
117+
"""Parse the OIDC configuration."""
118+
return {
119+
"type": "oauth2",
120+
"description": "requires an authentication token",
121+
"flows": {
122+
"authorizationCode": {
123+
"authorizationUrl": oidc_config["authorization_endpoint"],
124+
"tokenUrl": oidc_config.get("token_endpoint"),
125+
"scopes": {
126+
k: k for k in sorted(oidc_config.get("scopes_supported", []))
127+
},
128+
},
129+
},
130+
}

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Middleware to enforce authentication."""
22

33
import logging
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, field
55
from typing import Annotated, Any, Optional, Sequence
66
from urllib.parse import urlparse, urlunparse
77

@@ -18,6 +18,53 @@
1818
logger = logging.getLogger(__name__)
1919

2020

21+
@dataclass
22+
class OidcService:
23+
"""OIDC configuration and JWKS client."""
24+
25+
oidc_config_url: HttpUrl
26+
jwks_client: jwt.PyJWKClient = field(init=False)
27+
metadata: dict[str, Any] = field(init=False)
28+
29+
def __post_init__(self) -> None:
30+
"""Initialize OIDC config and JWKS client."""
31+
logger.debug("Requesting OIDC config")
32+
origin_url = str(self.oidc_config_url)
33+
34+
try:
35+
response = httpx.get(origin_url)
36+
response.raise_for_status()
37+
self.metadata = response.json()
38+
assert self.metadata, "OIDC metadata is empty"
39+
40+
# NOTE: We manually replace the origin of the jwks_uri in the event that
41+
# the jwks_uri is not available from within the proxy.
42+
oidc_url = urlparse(origin_url)
43+
jwks_uri = urlunparse(
44+
urlparse(self.metadata["jwks_uri"])._replace(
45+
netloc=oidc_url.netloc, scheme=oidc_url.scheme
46+
)
47+
)
48+
if jwks_uri != self.metadata["jwks_uri"]:
49+
logger.warning(
50+
"JWKS URI has been rewritten from %s to %s",
51+
self.metadata["jwks_uri"],
52+
jwks_uri,
53+
)
54+
self.jwks_client = jwt.PyJWKClient(jwks_uri)
55+
except httpx.HTTPStatusError as e:
56+
logger.error(
57+
"Received a non-200 response when fetching OIDC config: %s",
58+
e.response.text,
59+
)
60+
raise OidcFetchError(
61+
f"Request for OIDC config failed with status {e.response.status_code}"
62+
) from e
63+
except httpx.RequestError as e:
64+
logger.error("Error fetching OIDC config from %s: %s", origin_url, str(e))
65+
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
66+
67+
2168
@dataclass
2269
class EnforceAuthMiddleware:
2370
"""Middleware to enforce authentication."""
@@ -26,56 +73,11 @@ class EnforceAuthMiddleware:
2673
private_endpoints: EndpointMethods
2774
public_endpoints: EndpointMethods
2875
default_public: bool
29-
3076
oidc_config_url: HttpUrl
3177
allowed_jwt_audiences: Optional[Sequence[str]] = None
32-
3378
state_key: str = "payload"
3479

35-
# Generated attributes
36-
_jwks_client: Optional[jwt.PyJWKClient] = None
37-
38-
@property
39-
def jwks_client(self) -> jwt.PyJWKClient:
40-
"""Get the OIDC configuration URL."""
41-
if not self._jwks_client:
42-
logger.debug("Requesting OIDC config")
43-
origin_url = str(self.oidc_config_url)
44-
45-
try:
46-
response = httpx.get(origin_url)
47-
response.raise_for_status()
48-
oidc_config = response.json()
49-
50-
# NOTE: We manually replace the origin of the jwks_uri in the event that
51-
# the jwks_uri is not available from within the proxy.
52-
oidc_url = urlparse(origin_url)
53-
jwks_uri = urlunparse(
54-
urlparse(oidc_config["jwks_uri"])._replace(
55-
netloc=oidc_url.netloc, scheme=oidc_url.scheme
56-
)
57-
)
58-
if jwks_uri != oidc_config["jwks_uri"]:
59-
logger.warning(
60-
"JWKS URI has been rewritten from %s to %s",
61-
oidc_config["jwks_uri"],
62-
jwks_uri,
63-
)
64-
self._jwks_client = jwt.PyJWKClient(jwks_uri)
65-
except httpx.HTTPStatusError as e:
66-
logger.error(
67-
"Received a non-200 response when fetching OIDC config: %s",
68-
e.response.text,
69-
)
70-
raise OidcFetchError(
71-
f"Request for OIDC config failed with status {e.response.status_code}"
72-
) from e
73-
except httpx.RequestError as e:
74-
logger.error(
75-
"Error fetching OIDC config from %s: %s", origin_url, str(e)
76-
)
77-
raise OidcFetchError(f"Request for OIDC config failed: {str(e)}") from e
78-
return self._jwks_client
80+
_oidc_config: Optional[OidcService] = None
7981

8082
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
8183
"""Enforce authentication."""
@@ -107,6 +109,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
107109
self.state_key,
108110
payload,
109111
)
112+
setattr(request.state, "oidc_metadata", self.oidc_config.metadata)
110113
return await self.app(scope, receive, send)
111114

112115
def validate_token(
@@ -137,7 +140,7 @@ def validate_token(
137140

138141
# Parse & validate token
139142
try:
140-
key = self.jwks_client.get_signing_key_from_jwt(token).key
143+
key = self.oidc_config.jwks_client.get_signing_key_from_jwt(token).key
141144
payload = jwt.decode(
142145
token,
143146
key,
@@ -163,6 +166,13 @@ def validate_token(
163166
)
164167
return payload
165168

169+
@property
170+
def oidc_config(self) -> OidcService:
171+
"""Get the OIDC configuration."""
172+
if not self._oidc_config:
173+
self._oidc_config = OidcService(oidc_config_url=self.oidc_config_url)
174+
return self._oidc_config
175+
166176

167177
class OidcFetchError(Exception):
168178
"""Error fetching OIDC configuration."""

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from starlette.datastructures import Headers
88
from starlette.requests import Request
9-
from starlette.types import ASGIApp
9+
from starlette.types import ASGIApp, Scope
1010

1111
from ..config import EndpointMethods
1212
from ..utils.middleware import JsonResponseMiddleware
@@ -41,7 +41,9 @@ def should_transform_response(
4141
]
4242
)
4343

44-
def transform_json(self, openapi_spec: dict[str, Any]) -> dict[str, Any]:
44+
def transform_json(
45+
self, openapi_spec: dict[str, Any], scope: Scope
46+
) -> dict[str, Any]:
4547
"""Augment the OpenAPI spec with auth information."""
4648
components = openapi_spec.setdefault("components", {})
4749
securitySchemes = components.setdefault("securitySchemes", {})

src/stac_auth_proxy/utils/middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def should_transform_response(
2929
...
3030

3131
@abstractmethod
32-
def transform_json(self, data: Any) -> Any:
32+
def transform_json(self, data: Any, scope: Scope) -> Any:
3333
"""
3434
Transform the JSON data.
3535
@@ -78,7 +78,7 @@ async def transform_response(message: Message) -> None:
7878
# Transform the JSON body
7979
if body:
8080
data = json.loads(body)
81-
transformed = self.transform_json(data)
81+
transformed = self.transform_json(data, scope=scope)
8282
body = json.dumps(transformed).encode()
8383

8484
# Update content-length header

tests/test_middleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from starlette.datastructures import Headers
77
from starlette.requests import Request
88
from starlette.testclient import TestClient
9-
from starlette.types import ASGIApp
9+
from starlette.types import ASGIApp, Scope
1010

1111
from stac_auth_proxy.utils.middleware import JsonResponseMiddleware
1212

@@ -24,7 +24,7 @@ def should_transform_response(
2424
"""Transform JSON responses based on content type."""
2525
return response_headers.get("content-type", "") == "application/json"
2626

27-
def transform_json(self, data: Any) -> Any:
27+
def transform_json(self, data: Any, scope: Scope) -> Any:
2828
"""Add a test field to the response."""
2929
if isinstance(data, dict):
3030
data["transformed"] = True

0 commit comments

Comments
 (0)