11"""Middleware to enforce authentication."""
22
33import logging
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , field
55from typing import Annotated , Any , Optional , Sequence
66from urllib .parse import urlparse , urlunparse
77
1818logger = logging .getLogger (__name__ )
1919
2020
21+ @dataclass
22+ class OidcService :
23+ """OIDC configuration and JWKS client."""
24+
25+ oidc_config_url : HttpUrl
26+ jwks_client : jwt .PyJWKClient = field (init = False )
27+ metadata : dict [str , Any ] = field (init = False )
28+
29+ def __post_init__ (self ) -> None :
30+ """Initialize OIDC config and JWKS client."""
31+ logger .debug ("Requesting OIDC config" )
32+ origin_url = str (self .oidc_config_url )
33+
34+ try :
35+ response = httpx .get (origin_url )
36+ response .raise_for_status ()
37+ self .metadata = response .json ()
38+ assert self .metadata , "OIDC metadata is empty"
39+
40+ # NOTE: We manually replace the origin of the jwks_uri in the event that
41+ # the jwks_uri is not available from within the proxy.
42+ oidc_url = urlparse (origin_url )
43+ jwks_uri = urlunparse (
44+ urlparse (self .metadata ["jwks_uri" ])._replace (
45+ netloc = oidc_url .netloc , scheme = oidc_url .scheme
46+ )
47+ )
48+ if jwks_uri != self .metadata ["jwks_uri" ]:
49+ logger .warning (
50+ "JWKS URI has been rewritten from %s to %s" ,
51+ self .metadata ["jwks_uri" ],
52+ jwks_uri ,
53+ )
54+ self .jwks_client = jwt .PyJWKClient (jwks_uri )
55+ except httpx .HTTPStatusError as e :
56+ logger .error (
57+ "Received a non-200 response when fetching OIDC config: %s" ,
58+ e .response .text ,
59+ )
60+ raise OidcFetchError (
61+ f"Request for OIDC config failed with status { e .response .status_code } "
62+ ) from e
63+ except httpx .RequestError as e :
64+ logger .error ("Error fetching OIDC config from %s: %s" , origin_url , str (e ))
65+ raise OidcFetchError (f"Request for OIDC config failed: { str (e )} " ) from e
66+
67+
2168@dataclass
2269class EnforceAuthMiddleware :
2370 """Middleware to enforce authentication."""
@@ -26,56 +73,11 @@ class EnforceAuthMiddleware:
2673 private_endpoints : EndpointMethods
2774 public_endpoints : EndpointMethods
2875 default_public : bool
29-
3076 oidc_config_url : HttpUrl
3177 allowed_jwt_audiences : Optional [Sequence [str ]] = None
32-
3378 state_key : str = "payload"
3479
35- # Generated attributes
36- _jwks_client : Optional [jwt .PyJWKClient ] = None
37-
38- @property
39- def jwks_client (self ) -> jwt .PyJWKClient :
40- """Get the OIDC configuration URL."""
41- if not self ._jwks_client :
42- logger .debug ("Requesting OIDC config" )
43- origin_url = str (self .oidc_config_url )
44-
45- try :
46- response = httpx .get (origin_url )
47- response .raise_for_status ()
48- oidc_config = response .json ()
49-
50- # NOTE: We manually replace the origin of the jwks_uri in the event that
51- # the jwks_uri is not available from within the proxy.
52- oidc_url = urlparse (origin_url )
53- jwks_uri = urlunparse (
54- urlparse (oidc_config ["jwks_uri" ])._replace (
55- netloc = oidc_url .netloc , scheme = oidc_url .scheme
56- )
57- )
58- if jwks_uri != oidc_config ["jwks_uri" ]:
59- logger .warning (
60- "JWKS URI has been rewritten from %s to %s" ,
61- oidc_config ["jwks_uri" ],
62- jwks_uri ,
63- )
64- self ._jwks_client = jwt .PyJWKClient (jwks_uri )
65- except httpx .HTTPStatusError as e :
66- logger .error (
67- "Received a non-200 response when fetching OIDC config: %s" ,
68- e .response .text ,
69- )
70- raise OidcFetchError (
71- f"Request for OIDC config failed with status { e .response .status_code } "
72- ) from e
73- except httpx .RequestError as e :
74- logger .error (
75- "Error fetching OIDC config from %s: %s" , origin_url , str (e )
76- )
77- raise OidcFetchError (f"Request for OIDC config failed: { str (e )} " ) from e
78- return self ._jwks_client
80+ _oidc_config : Optional [OidcService ] = None
7981
8082 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
8183 """Enforce authentication."""
@@ -107,6 +109,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
107109 self .state_key ,
108110 payload ,
109111 )
112+ setattr (request .state , "oidc_metadata" , self .oidc_config .metadata )
110113 return await self .app (scope , receive , send )
111114
112115 def validate_token (
@@ -137,7 +140,7 @@ def validate_token(
137140
138141 # Parse & validate token
139142 try :
140- key = self .jwks_client .get_signing_key_from_jwt (token ).key
143+ key = self .oidc_config . jwks_client .get_signing_key_from_jwt (token ).key
141144 payload = jwt .decode (
142145 token ,
143146 key ,
@@ -163,6 +166,13 @@ def validate_token(
163166 )
164167 return payload
165168
169+ @property
170+ def oidc_config (self ) -> OidcService :
171+ """Get the OIDC configuration."""
172+ if not self ._oidc_config :
173+ self ._oidc_config = OidcService (oidc_config_url = self .oidc_config_url )
174+ return self ._oidc_config
175+
166176
167177class OidcFetchError (Exception ):
168178 """Error fetching OIDC configuration."""
0 commit comments