22
33import logging
44import re
5- from dataclasses import dataclass
5+ from dataclasses import dataclass , field
66from itertools import chain
77from typing import Any , Optional
88from urllib .parse import urlparse
99
10+ import httpx
11+ from pydantic import HttpUrl
1012from starlette .requests import Request
1113from starlette .types import ASGIApp
1214
1719logger = logging .getLogger (__name__ )
1820
1921
20- @dataclass ( frozen = True )
22+ @dataclass
2123class AuthenticationExtensionMiddleware (JsonResponseMiddleware ):
2224 """Middleware to add the authentication extension to the response."""
2325
@@ -30,22 +32,48 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3032 private_endpoints : EndpointMethods
3133 public_endpoints : EndpointMethods
3234
33- signing_scheme : str = "signed_url_auth"
34- auth_scheme : str = "oauth"
35+ oidc_config_url : Optional [HttpUrl ] = None
36+ signing_scheme_name : str = "signed_url_auth"
37+ auth_scheme_name : str = "oauth"
38+ auth_scheme : dict [str , Any ] = field (default_factory = dict )
39+ extension_url : str = (
40+ "https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
41+ )
42+
43+ def __post_init__ (self ):
44+ """Load after initialization."""
45+ if self .oidc_config_url and not self .auth_scheme :
46+ # Retrieve OIDC configuration and extract authorization and token URLs
47+ oidc_config = httpx .get (str (self .oidc_config_url )).json ()
48+ self .auth_scheme = {
49+ "type" : "oauth2" ,
50+ "description" : "requires an authentication token" ,
51+ "flows" : {
52+ "authorizationCode" : {
53+ "authorizationUrl" : oidc_config .get ("authorization_endpoint" ),
54+ "tokenUrl" : oidc_config .get ("token_endpoint" ),
55+ "scopes" : {
56+ k : k
57+ for k in sorted (oidc_config .get ("scopes_supported" , []))
58+ },
59+ },
60+ },
61+ }
3562
3663 def should_transform_response (self , request : Request ) -> bool :
3764 """Determine if the response should be transformed."""
38- print (f"{ request .url = !s} " )
39- return True
65+ # Match STAC catalog, collection, or item URLs with a single regex
66+ return bool (
67+ re .match (
68+ r"^(/|/collections(/[^/]+(/items/[^/]+)?)?|/search)$" , request .url .path
69+ )
70+ )
4071
4172 def transform_json (self , doc : dict [str , Any ]) -> dict [str , Any ]:
4273 """Augment the STAC Item with auth information."""
43- extension = (
44- "https://stac-extensions.github.io/authentication/v1.1.0/schema.json"
45- )
4674 extensions = doc .setdefault ("stac_extensions" , [])
47- if extension not in extensions :
48- extensions .append (extension )
75+ if self . extension_url not in extensions :
76+ extensions .append (self . extension_url )
4977
5078 # TODO: Should we add this to items even if the assets don't match the asset expression?
5179 # auth:schemes
@@ -55,64 +83,41 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
5583 # - Catalogs
5684 # - Collections
5785 # - Item Properties
58- # "auth:schemes": {
59- # "oauth": {
60- # "type": "oauth2",
61- # "description": "requires a login and user token",
62- # "flows": {
63- # "authorizationUrl": "https://example.com/oauth/authorize",
64- # "tokenUrl": "https://example.com/oauth/token",
65- # "scopes": {}
66- # }
67- # }
68- # }
69- # TODO: Add directly to Collections & Catalogs doc
70- if "properties" in doc :
71- schemes = doc ["properties" ].setdefault ("auth:schemes" , {})
72- schemes [self .auth_scheme ] = {
73- "type" : "oauth2" ,
74- "description" : "requires a login and user token" ,
86+ scheme_loc = doc ["properties" ] if "properties" in doc else doc
87+ schemes = scheme_loc .setdefault ("auth:schemes" , {})
88+ schemes [self .auth_scheme_name ] = self .auth_scheme
89+ if self .signing_endpoint :
90+ schemes [self .signing_scheme_name ] = {
91+ "type" : "signedUrl" ,
92+ "description" : "Requires an authentication API" ,
7593 "flows" : {
76- # TODO: Get authorizationUrl and tokenUrl from config
7794 "authorizationCode" : {
78- "authorizationUrl" : "https://example.com/oauth/authorize" ,
79- "tokenUrl" : "https://example.com/oauth/token" ,
80- "scopes" : {},
81- },
82- },
83- }
84- if self .signing_endpoint :
85- schemes [self .signing_scheme ] = {
86- "type" : "signedUrl" ,
87- "description" : "Requires an authentication API" ,
88- "flows" : {
89- "authorizationCode" : {
90- "authorizationApi" : self .signing_endpoint ,
91- "method" : "POST" ,
92- "parameters" : {
93- "bucket" : {
94- "in" : "body" ,
95- "required" : True ,
96- "description" : "asset bucket" ,
97- "schema" : {
98- "type" : "string" ,
99- "examples" : "example-bucket" ,
100- },
95+ "authorizationApi" : self .signing_endpoint ,
96+ "method" : "POST" ,
97+ "parameters" : {
98+ "bucket" : {
99+ "in" : "body" ,
100+ "required" : True ,
101+ "description" : "asset bucket" ,
102+ "schema" : {
103+ "type" : "string" ,
104+ "examples" : "example-bucket" ,
101105 },
102- "key" : {
103- "in " : "body" ,
104- "required " : True ,
105- "description " : "asset key" ,
106- "schema " : {
107- "type " : "string" ,
108- "examples " : "path/to/example/asset.xyz " ,
109- } ,
106+ },
107+ "key " : {
108+ "in " : "body" ,
109+ "required " : True ,
110+ "description " : "asset key" ,
111+ "schema " : {
112+ "type " : "string " ,
113+ "examples" : "path/to/example/asset.xyz" ,
110114 },
111115 },
112- "responseField" : "signed_url" ,
113- }
114- },
115- }
116+ },
117+ "responseField" : "signed_url" ,
118+ }
119+ },
120+ }
116121
117122 # auth:refs
118123 # ---
@@ -123,7 +128,7 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
123128 logger .warning ("Asset %s has no href" , asset )
124129 continue
125130 if re .match (self .signed_asset_expression , asset ["href" ]):
126- asset .setdefault ("auth:refs" , []).append (self .signing_scheme )
131+ asset .setdefault ("auth:refs" , []).append (self .signing_scheme_name )
127132
128133 # Annotate links with "auth:refs": [auth_scheme]
129134 links = chain (
@@ -136,7 +141,6 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
136141 ),
137142 )
138143 for link in links :
139- print (f"{ link ['href' ]= !s} " )
140144 if "href" not in link :
141145 logger .warning ("Link %s has no href" , link )
142146 continue
@@ -148,6 +152,6 @@ def transform_json(self, doc: dict[str, Any]) -> dict[str, Any]:
148152 default_public = self .default_public ,
149153 )
150154 if match .is_private :
151- link .setdefault ("auth:refs" , []).append (self .auth_scheme )
155+ link .setdefault ("auth:refs" , []).append (self .auth_scheme_name )
152156
153157 return doc
0 commit comments