99
1010from starlette .datastructures import Headers
1111from starlette .requests import Request
12- from starlette .types import ASGIApp , Scope
12+ from starlette .types import ASGIApp
1313
1414from ..config import EndpointMethods
1515from ..utils .middleware import JsonResponseMiddleware
@@ -36,6 +36,8 @@ class AuthenticationExtensionMiddleware(JsonResponseMiddleware):
3636
3737 json_content_type_expr : str = r"(application/json|geo\+json)"
3838
39+ state_key : str = "oidc_metadata"
40+
3941 def should_transform_response (
4042 self , request : Request , response_headers : Headers
4143 ) -> bool :
@@ -55,9 +57,9 @@ def should_transform_response(
5557 ]
5658 )
5759
58- def transform_json (self , doc : dict [str , Any ], scope : Scope ) -> dict [str , Any ]:
60+ def transform_json (self , data : dict [str , Any ], request : Request ) -> dict [str , Any ]:
5961 """Augment the STAC Item with auth information."""
60- extensions = doc .setdefault ("stac_extensions" , [])
62+ extensions = data .setdefault ("stac_extensions" , [])
6163 if self .extension_url not in extensions :
6264 extensions .append (self .extension_url )
6365
@@ -70,30 +72,30 @@ def transform_json(self, doc: dict[str, Any], scope: Scope) -> dict[str, Any]:
7072 # - Collections
7173 # - Item Properties
7274
73- if "oidc_metadata" not in scope :
75+ if self . state_key not in request . state :
7476 logger .error (
7577 "OIDC metadata not found in scope. "
7678 "Skipping authentication extension."
7779 )
78- return doc
80+ return data
7981
80- scheme_loc = doc ["properties" ] if "properties" in doc else doc
82+ scheme_loc = data ["properties" ] if "properties" in data else data
8183 schemes = scheme_loc .setdefault ("auth:schemes" , {})
8284 schemes [self .auth_scheme_name ] = self .parse_oidc_config (
83- scope . get ("oidc_metadata" , {})
85+ request . state . get (self . state_key , {})
8486 )
8587
8688 # auth:refs
8789 # ---
8890 # Annotate links with "auth:refs": [auth_scheme]
8991 links = chain (
9092 # Item/Collection
91- doc .get ("links" , []),
93+ data .get ("links" , []),
9294 # Collections/Items/Search
9395 (
9496 link
9597 for prop in ["features" , "collections" ]
96- for object_with_links in doc .get (prop , [])
98+ for object_with_links in data .get (prop , [])
9799 for link in object_with_links .get ("links" , [])
98100 ),
99101 )
@@ -111,7 +113,7 @@ def transform_json(self, doc: dict[str, Any], scope: Scope) -> dict[str, Any]:
111113 if match .is_private :
112114 link .setdefault ("auth:refs" , []).append (self .auth_scheme_name )
113115
114- return doc
116+ return data
115117
116118 def parse_oidc_config (self , oidc_config : dict [str , Any ]) -> dict [str , Any ]:
117119 """Parse the OIDC configuration."""
0 commit comments