11"""Middleware to enforce authentication."""
22
3- import json
43import logging
5- import urllib .request
64from dataclasses import dataclass , field
75from typing import Annotated , Optional , Sequence
86
7+ import httpx
98import jwt
109from fastapi import HTTPException , Request , Security , status
1110from pydantic import HttpUrl
@@ -28,7 +27,7 @@ class EnforceAuthMiddleware:
2827 default_public : bool
2928
3029 oidc_config_url : HttpUrl
31- openid_configuration_internal_url : Optional [HttpUrl ] = None
30+ oidc_config_internal_url : Optional [HttpUrl ] = None
3231 allowed_jwt_audiences : Optional [Sequence [str ]] = None
3332
3433 state_key : str = "user"
@@ -39,18 +38,24 @@ class EnforceAuthMiddleware:
3938 def __post_init__ (self ):
4039 """Initialize the OIDC authentication class."""
4140 logger .debug ("Requesting OIDC config" )
42- origin_url = str (self .openid_configuration_internal_url or self .oidc_config_url )
43- with urllib .request .urlopen (origin_url ) as response :
44- if response .status != 200 :
45- logger .error (
46- "Received a non-200 response when fetching OIDC config: %s" ,
47- response .text ,
48- )
49- raise OidcFetchError (
50- f"Request for OIDC config failed with status { response .status } "
51- )
52- oidc_config = json .load (response )
41+ origin_url = str (self .oidc_config_internal_url or self .oidc_config_url )
42+
43+ try :
44+ response = httpx .get (origin_url )
45+ response .raise_for_status ()
46+ oidc_config = response .json ()
5347 self .jwks_client = jwt .PyJWKClient (oidc_config ["jwks_uri" ])
48+ except httpx .HTTPStatusError as e :
49+ logger .error (
50+ "Received a non-200 response when fetching OIDC config: %s" ,
51+ e .response .text ,
52+ )
53+ raise OidcFetchError (
54+ f"Request for OIDC config failed with status { e .response .status_code } "
55+ )
56+ except httpx .RequestError as e :
57+ logger .error ("Error fetching OIDC config from %s: %s" , origin_url , str (e ))
58+ raise OidcFetchError (f"Request for OIDC config failed: { str (e )} " )
5459
5560 async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
5661 """Enforce authentication."""
0 commit comments