22
33import logging
44from dataclasses import dataclass , field
5- from typing import Annotated , Optional , Sequence
5+ from typing import Annotated , Any , Optional , Sequence
66
77import httpx
88import jwt
@@ -30,32 +30,37 @@ class EnforceAuthMiddleware:
3030 oidc_config_internal_url : Optional [HttpUrl ] = None
3131 allowed_jwt_audiences : Optional [Sequence [str ]] = None
3232
33- state_key : str = "user "
33+ state_key : str = "payload "
3434
3535 # Generated attributes
36- jwks_client : jwt .PyJWKClient = field (init = False )
37-
38- def __post_init__ (self ):
39- """Initialize the OIDC authentication class."""
40- logger .debug ("Requesting OIDC config" )
41- origin_url = str (self .oidc_config_internal_url or self .oidc_config_url )
42-
43- try :
44- response = httpx .get (origin_url )
45- response .raise_for_status ()
46- oidc_config = response .json ()
47- self .jwks_client = jwt .PyJWKClient (oidc_config ["jwks_uri" ])
48- except httpx .HTTPStatusError as e :
49- logger .error (
50- "Received a non-200 response when fetching OIDC config: %s" ,
51- e .response .text ,
52- )
53- raise OidcFetchError (
54- f"Request for OIDC config failed with status { e .response .status_code } "
55- )
56- except httpx .RequestError as e :
57- logger .error ("Error fetching OIDC config from %s: %s" , origin_url , str (e ))
58- raise OidcFetchError (f"Request for OIDC config failed: { str (e )} " )
36+ _jwks_client : Optional [jwt .PyJWKClient ] = None
37+
38+ @property
39+ def jwks_client (self ) -> HttpUrl :
40+ """Get the OIDC configuration URL."""
41+ if not self ._jwks_client :
42+ logger .debug ("Requesting OIDC config" )
43+ origin_url = str (self .oidc_config_internal_url or self .oidc_config_url )
44+
45+ try :
46+ response = httpx .get (origin_url )
47+ response .raise_for_status ()
48+ oidc_config = response .json ()
49+ self ._jwks_client = jwt .PyJWKClient (oidc_config ["jwks_uri" ])
50+ except httpx .HTTPStatusError as e :
51+ logger .error (
52+ "Received a non-200 response when fetching OIDC config: %s" ,
53+ e .response .text ,
54+ )
55+ raise OidcFetchError (
56+ f"Request for OIDC config failed with status { e .response .status_code } "
57+ ) from e
58+ except httpx .RequestError as e :
59+ logger .error (
60+ "Error fetching OIDC config from %s: %s" , origin_url , str (e )
61+ )
62+ raise OidcFetchError (f"Request for OIDC config failed: { str (e )} " ) from e
63+ return self ._jwks_client
5964
6065 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
6166 """Enforce authentication."""
@@ -64,17 +69,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
6469
6570 request = Request (scope )
6671 try :
67- setattr (
68- request .state ,
69- self .state_key ,
70- self .validated_user (
71- request .headers .get ("Authorization" ),
72- auto_error = self .should_enforce_auth (request ),
73- ),
72+ payload = self .validate_token (
73+ request .headers .get ("Authorization" ),
74+ auto_error = self .should_enforce_auth (request ),
7475 )
7576 except HTTPException as e :
7677 response = JSONResponse ({"detail" : e .detail }, status_code = e .status_code )
7778 return await response (scope , receive , send )
79+
80+ # Set the payload in the request state
81+ setattr (
82+ request .state ,
83+ self .state_key ,
84+ payload ,
85+ )
7886 return await self .app (scope , receive , send )
7987
8088 def should_enforce_auth (self , request : Request ) -> bool :
@@ -85,11 +93,11 @@ def should_enforce_auth(self, request: Request) -> bool:
8593 # If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
8694 return not matches_route (request , self .public_endpoints )
8795
88- def validated_user (
96+ def validate_token (
8997 self ,
9098 auth_header : Annotated [str , Security (...)],
9199 auto_error : bool = True ,
92- ):
100+ ) -> Optional [ dict [ str , Any ]] :
93101 """Dependency to validate an OIDC token."""
94102 if not auth_header :
95103 if auto_error :
0 commit comments