Skip to content

Commit 1ce8ed5

Browse files
committed
feat: update authentication extension integratino to use discovery URL
Along the way, normalize all OIDC discovery url input arguments to match config.
1 parent adbed6b commit 1ce8ed5

File tree

5 files changed

+26
-57
lines changed

5 files changed

+26
-57
lines changed

src/stac_auth_proxy/app.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,13 +103,14 @@ async def lifespan(app: FastAPI):
103103
default_public=settings.default_public,
104104
public_endpoints=settings.public_endpoints,
105105
private_endpoints=settings.private_endpoints,
106+
oidc_discovery_url=str(settings.oidc_discovery_url),
106107
)
107108

108109
if settings.openapi_spec_endpoint:
109110
app.add_middleware(
110111
OpenApiMiddleware,
111112
openapi_spec_path=settings.openapi_spec_endpoint,
112-
oidc_config_url=str(settings.oidc_discovery_url),
113+
oidc_discovery_url=str(settings.oidc_discovery_url),
113114
public_endpoints=settings.public_endpoints,
114115
private_endpoints=settings.private_endpoints,
115116
default_public=settings.default_public,
@@ -136,7 +137,7 @@ async def lifespan(app: FastAPI):
136137
public_endpoints=settings.public_endpoints,
137138
private_endpoints=settings.private_endpoints,
138139
default_public=settings.default_public,
139-
oidc_config_url=settings.oidc_discovery_internal_url,
140+
oidc_discovery_url=settings.oidc_discovery_internal_url,
140141
)
141142

142143
if settings.root_path or settings.upstream_url.path != "/":

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 4 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,15 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
2828
private_endpoints: EndpointMethods
2929
public_endpoints: EndpointMethods
3030

31-
auth_scheme_name: str = "oauth"
31+
oidc_discovery_url: str
32+
auth_scheme_name: str = "oidc"
3233
auth_scheme: dict[str, Any] = field(default_factory=dict)
3334
extension_url: str = (
3435
"https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
3536
)
3637

3738
json_content_type_expr: str = r"application/(geo\+)?json"
3839

39-
state_key: str = "oidc_metadata"
40-
4140
def should_transform_response(self, request: Request, scope: Scope) -> bool:
4241
"""Determine if the response should be transformed."""
4342
# Match STAC catalog, collection, or item URLs with a single regex
@@ -75,27 +74,11 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
7574
# - Collections
7675
# - Item Properties
7776

78-
oidc_metadata = getattr(request.state, self.state_key, {})
79-
if not oidc_metadata:
80-
logger.error(
81-
"OIDC metadata not found in scope. Skipping authentication extension."
82-
)
83-
return data
84-
8577
scheme_loc = data["properties"] if "properties" in data else data
8678
schemes = scheme_loc.setdefault("auth:schemes", {})
8779
schemes[self.auth_scheme_name] = {
88-
"type": "oauth2",
89-
"description": "requires an authentication bearertoken",
90-
"flows": {
91-
"authorizationCode": {
92-
"authorizationUrl": oidc_metadata["authorization_endpoint"],
93-
"tokenUrl": oidc_metadata.get("token_endpoint"),
94-
"scopes": {
95-
k: k for k in sorted(oidc_metadata.get("scopes_supported", []))
96-
},
97-
},
98-
},
80+
"type": "openIdConnect",
81+
"openIdConnectUrl": self.oidc_discovery_url,
9982
}
10083

10184
# auth:refs

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,14 @@
2222
class OidcService:
2323
"""OIDC configuration and JWKS client."""
2424

25-
oidc_config_url: HttpUrl
25+
oidc_discovery_url: HttpUrl
2626
jwks_client: jwt.PyJWKClient = field(init=False)
2727
metadata: dict[str, Any] = field(init=False)
2828

2929
def __post_init__(self) -> None:
3030
"""Initialize OIDC config and JWKS client."""
3131
logger.debug("Requesting OIDC config")
32-
origin_url = str(self.oidc_config_url)
32+
origin_url = str(self.oidc_discovery_url)
3333

3434
try:
3535
response = httpx.get(origin_url)
@@ -73,7 +73,7 @@ class EnforceAuthMiddleware:
7373
private_endpoints: EndpointMethods
7474
public_endpoints: EndpointMethods
7575
default_public: bool
76-
oidc_config_url: HttpUrl
76+
oidc_discovery_url: HttpUrl
7777
allowed_jwt_audiences: Optional[Sequence[str]] = None
7878
state_key: str = "payload"
7979

@@ -170,7 +170,7 @@ def validate_token(
170170
def oidc_config(self) -> OidcService:
171171
"""Get the OIDC configuration."""
172172
if not self._oidc_config:
173-
self._oidc_config = OidcService(oidc_config_url=self.oidc_config_url)
173+
self._oidc_config = OidcService(oidc_discovery_url=self.oidc_discovery_url)
174174
return self._oidc_config
175175

176176

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ class OpenApiMiddleware(JsonResponseMiddleware):
1919

2020
app: ASGIApp
2121
openapi_spec_path: str
22-
oidc_config_url: str
22+
oidc_discovery_url: str
2323
private_endpoints: EndpointMethods
2424
public_endpoints: EndpointMethods
2525
default_public: bool
@@ -56,7 +56,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
5656
securitySchemes = components.setdefault("securitySchemes", {})
5757
securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or {
5858
"type": "openIdConnect",
59-
"openIdConnectUrl": self.oidc_config_url,
59+
"openIdConnectUrl": self.oidc_discovery_url,
6060
}
6161

6262
# Add security to private endpoints

tests/test_auth_extension.py

Lines changed: 13 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,20 @@
1010

1111

1212
@pytest.fixture
13-
def middleware():
13+
def oidc_discovery_url():
14+
"""Create test OIDC discovery URL."""
15+
return "https://auth.example.com/discovery"
16+
17+
18+
@pytest.fixture
19+
def middleware(oidc_discovery_url):
1420
"""Create a test instance of the middleware."""
1521
return AuthenticationExtensionMiddleware(
1622
app=None, # We don't need the actual app for these tests
1723
default_public=True,
1824
private_endpoints=EndpointMethods(),
1925
public_endpoints=EndpointMethods(),
26+
oidc_discovery_url=oidc_discovery_url,
2027
auth_scheme_name="test_auth",
2128
auth_scheme={},
2229
)
@@ -49,16 +56,6 @@ def initial_message(request):
4956
}
5057

5158

52-
@pytest.fixture
53-
def oidc_metadata():
54-
"""Create test OIDC metadata."""
55-
return {
56-
"authorization_endpoint": "https://auth.example.com/auth",
57-
"token_endpoint": "https://auth.example.com/token",
58-
"scopes_supported": ["openid", "profile"],
59-
}
60-
61-
6259
def test_should_transform_response_valid_paths(
6360
middleware, request_scope, initial_message
6461
):
@@ -113,10 +110,9 @@ def test_should_transform_response_invalid_content_type(middleware, request_scop
113110
)
114111

115112

116-
def test_transform_json_catalog(middleware, request_scope, oidc_metadata):
113+
def test_transform_json_catalog(middleware, request_scope, oidc_discovery_url):
117114
"""Test transforming a STAC catalog."""
118115
request = Request(request_scope)
119-
request.state.oidc_metadata = oidc_metadata
120116

121117
catalog = {
122118
"stac_version": "1.0.0",
@@ -136,23 +132,13 @@ def test_transform_json_catalog(middleware, request_scope, oidc_metadata):
136132
assert "test_auth" in transformed["auth:schemes"]
137133

138134
scheme = transformed["auth:schemes"]["test_auth"]
139-
assert scheme["type"] == "oauth2"
140-
assert (
141-
scheme["flows"]["authorizationCode"]["authorizationUrl"]
142-
== oidc_metadata["authorization_endpoint"]
143-
)
144-
assert (
145-
scheme["flows"]["authorizationCode"]["tokenUrl"]
146-
== oidc_metadata["token_endpoint"]
147-
)
148-
assert "openid" in scheme["flows"]["authorizationCode"]["scopes"]
149-
assert "profile" in scheme["flows"]["authorizationCode"]["scopes"]
135+
assert scheme["type"] == "openIdConnect"
136+
assert scheme["openIdConnectUrl"] == oidc_discovery_url
150137

151138

152-
def test_transform_json_collection(middleware, request_scope, oidc_metadata):
139+
def test_transform_json_collection(middleware, request_scope):
153140
"""Test transforming a STAC collection."""
154141
request = Request(request_scope)
155-
request.state.oidc_metadata = oidc_metadata
156142

157143
collection = {
158144
"stac_version": "1.0.0",
@@ -173,10 +159,9 @@ def test_transform_json_collection(middleware, request_scope, oidc_metadata):
173159
assert "test_auth" in transformed["auth:schemes"]
174160

175161

176-
def test_transform_json_item(middleware, request_scope, oidc_metadata):
162+
def test_transform_json_item(middleware, request_scope):
177163
"""Test transforming a STAC item."""
178164
request = Request(request_scope)
179-
request.state.oidc_metadata = oidc_metadata
180165

181166
item = {
182167
"stac_version": "1.0.0",

0 commit comments

Comments
 (0)