11"""Middleware to enforce authentication."""
22
3- import json
43import logging
5- import urllib .request
6- from dataclasses import dataclass , field
7- from typing import Annotated , Optional , Sequence
4+ from dataclasses import dataclass
5+ from typing import Annotated , Any , Optional , Sequence
86
7+ import httpx
98import jwt
109from fastapi import HTTPException , Request , Security , status
1110from pydantic import HttpUrl
@@ -28,29 +27,40 @@ class EnforceAuthMiddleware:
2827 default_public : bool
2928
3029 oidc_config_url : HttpUrl
31- openid_configuration_internal_url : Optional [HttpUrl ] = None
30+ oidc_config_internal_url : Optional [HttpUrl ] = None
3231 allowed_jwt_audiences : Optional [Sequence [str ]] = None
3332
34- state_key : str = "user "
33+ state_key : str = "payload "
3534
3635 # Generated attributes
37- jwks_client : jwt .PyJWKClient = field (init = False )
38-
39- def __post_init__ (self ):
40- """Initialize the OIDC authentication class."""
41- logger .debug ("Requesting OIDC config" )
42- origin_url = str (self .openid_configuration_internal_url or self .oidc_config_url )
43- with urllib .request .urlopen (origin_url ) as response :
44- if response .status != 200 :
36+ _jwks_client : Optional [jwt .PyJWKClient ] = None
37+
38+ @property
39+ def jwks_client (self ) -> jwt .PyJWKClient :
40+ """Get the OIDC configuration URL."""
41+ if not self ._jwks_client :
42+ logger .debug ("Requesting OIDC config" )
43+ origin_url = str (self .oidc_config_internal_url or self .oidc_config_url )
44+
45+ try :
46+ response = httpx .get (origin_url )
47+ response .raise_for_status ()
48+ oidc_config = response .json ()
49+ self ._jwks_client = jwt .PyJWKClient (oidc_config ["jwks_uri" ])
50+ except httpx .HTTPStatusError as e :
4551 logger .error (
4652 "Received a non-200 response when fetching OIDC config: %s" ,
47- response .text ,
53+ e . response .text ,
4854 )
4955 raise OidcFetchError (
50- f"Request for OIDC config failed with status { response .status } "
56+ f"Request for OIDC config failed with status { e .response .status_code } "
57+ ) from e
58+ except httpx .RequestError as e :
59+ logger .error (
60+ "Error fetching OIDC config from %s: %s" , origin_url , str (e )
5161 )
52- oidc_config = json . load ( response )
53- self .jwks_client = jwt . PyJWKClient ( oidc_config [ "jwks_uri" ])
62+ raise OidcFetchError ( f"Request for OIDC config failed: { str ( e ) } " ) from e
63+ return self ._jwks_client
5464
5565 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
5666 """Enforce authentication."""
@@ -59,17 +69,20 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
5969
6070 request = Request (scope )
6171 try :
62- setattr (
63- request .state ,
64- self .state_key ,
65- self .validated_user (
66- request .headers .get ("Authorization" ),
67- auto_error = self .should_enforce_auth (request ),
68- ),
72+ payload = self .validate_token (
73+ request .headers .get ("Authorization" ),
74+ auto_error = self .should_enforce_auth (request ),
6975 )
7076 except HTTPException as e :
7177 response = JSONResponse ({"detail" : e .detail }, status_code = e .status_code )
7278 return await response (scope , receive , send )
79+
80+ # Set the payload in the request state
81+ setattr (
82+ request .state ,
83+ self .state_key ,
84+ payload ,
85+ )
7386 return await self .app (scope , receive , send )
7487
7588 def should_enforce_auth (self , request : Request ) -> bool :
@@ -80,11 +93,11 @@ def should_enforce_auth(self, request: Request) -> bool:
8093 # If not default_public, we enforce auth if the request is not for an endpoint explicitly listed as public
8194 return not matches_route (request , self .public_endpoints )
8295
83- def validated_user (
96+ def validate_token (
8497 self ,
8598 auth_header : Annotated [str , Security (...)],
8699 auto_error : bool = True ,
87- ):
100+ ) -> Optional [ dict [ str , Any ]] :
88101 """Dependency to validate an OIDC token."""
89102 if not auth_header :
90103 if auto_error :
0 commit comments