11import httpx
2+ import jwt
23from cachetools import TTLCache , cached
34from fastapi import Security
45from fastapi .security import OAuth2AuthorizationCodeBearer
5- from jose import JWTError , jwt
66
77from authentication .mock_token_generator import mock_rsa_public_key
88from authentication .models import User
1818
1919
2020@cached (cache = TTLCache (maxsize = 32 , ttl = 86400 ))
21- def fetch_openid_configuration () -> dict [ str , str ] :
21+ def fetch_openid_configuration () -> jwt . PyJWKClient :
2222 try :
2323 oid_conf_response = httpx .get (config .OAUTH_WELL_KNOWN )
2424 oid_conf_response .raise_for_status ()
2525 oid_conf = oid_conf_response .json ()
26- json_web_key_set_response = httpx .get (oid_conf ["jwks_uri" ])
27- json_web_key_set_response .raise_for_status ()
28- return {
29- "authorization_endpoint" : oid_conf ["authorization_endpoint" ],
30- "token_endpoint" : oid_conf ["token_endpoint" ],
31- "jwks" : json_web_key_set_response .json ()["keys" ],
32- }
26+ return jwt .PyJWKClient (oid_conf ["jwks_uri" ])
3327 except Exception as error :
3428 logger .error (f"Failed to fetch OpenId Connect configuration for '{ config .OAUTH_WELL_KNOWN } ': { error } " )
3529 raise credentials_exception
@@ -41,7 +35,11 @@ def auth_with_jwt(jwt_token: str = Security(oauth2_scheme)) -> User:
4135 if not jwt_token :
4236 raise credentials_exception
4337 # If TEST_TOKEN is true, we are running tests. Use the self-signed keys. If not, get keys from auth server.
44- key = mock_rsa_public_key if config .TEST_TOKEN else {"keys" : fetch_openid_configuration ()["jwks" ]}
38+ key = (
39+ mock_rsa_public_key
40+ if config .TEST_TOKEN
41+ else fetch_openid_configuration ().get_signing_key_from_jwt (jwt_token ).key
42+ )
4543
4644 try :
4745 payload = jwt .decode (jwt_token , key , algorithms = ["RS256" ], audience = config .OAUTH_AUDIENCE )
@@ -50,7 +48,7 @@ def auth_with_jwt(jwt_token: str = Security(oauth2_scheme)) -> User:
5048 user = User (user_id = payload ["oid" ], ** payload )
5149 else :
5250 user = User (user_id = payload ["sub" ], ** payload )
53- except JWTError as error :
51+ except jwt . exceptions . InvalidTokenError as error :
5452 logger .warning (f"Failed to decode JWT: { error } " )
5553 raise credentials_exception
5654
0 commit comments