2323from authlib .jose import JsonWebToken , jwt
2424from authlib .oauth2 .auth import ClientAuth
2525from authlib .oauth2 .rfc6749 .parameters import prepare_grant_uri
26- from authlib .oidc .core import CodeIDToken , ImplicitIDToken , UserInfo
26+ from authlib .oidc .core import CodeIDToken , UserInfo
2727from authlib .oidc .discovery import OpenIDProviderMetadata , get_well_known_url
2828from jinja2 import Environment , Template
2929from pymacaroons .exceptions import (
@@ -117,7 +117,8 @@ async def load_metadata(self) -> None:
117117 for idp_id , p in self ._providers .items ():
118118 try :
119119 await p .load_metadata ()
120- await p .load_jwks ()
120+ if not p ._uses_userinfo :
121+ await p .load_jwks ()
121122 except Exception as e :
122123 raise Exception (
123124 "Error while initialising OIDC provider %r" % (idp_id ,)
@@ -498,10 +499,6 @@ async def load_jwks(self, force: bool = False) -> JWKS:
498499 return await self ._jwks .get ()
499500
500501 async def _load_jwks (self ) -> JWKS :
501- if self ._uses_userinfo :
502- # We're not using jwt signing, return an empty jwk set
503- return {"keys" : []}
504-
505502 metadata = await self .load_metadata ()
506503
507504 # Load the JWKS using the `jwks_uri` metadata.
@@ -663,7 +660,7 @@ async def _fetch_userinfo(self, token: Token) -> UserInfo:
663660
664661 return UserInfo (resp )
665662
666- async def _parse_id_token (self , token : Token , nonce : str ) -> UserInfo :
663+ async def _parse_id_token (self , token : Token , nonce : str ) -> CodeIDToken :
667664 """Return an instance of UserInfo from token's ``id_token``.
668665
669666 Args:
@@ -673,7 +670,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
673670 request. This value should match the one inside the token.
674671
675672 Returns:
676- An object representing the user .
673+ The decoded claims in the ID token .
677674 """
678675 metadata = await self .load_metadata ()
679676 claims_params = {
@@ -684,9 +681,6 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
684681 # If we got an `access_token`, there should be an `at_hash` claim
685682 # in the `id_token` that we can check against.
686683 claims_params ["access_token" ] = token ["access_token" ]
687- claims_cls = CodeIDToken
688- else :
689- claims_cls = ImplicitIDToken
690684
691685 alg_values = metadata .get ("id_token_signing_alg_values_supported" , ["RS256" ])
692686 jwt = JsonWebToken (alg_values )
@@ -703,7 +697,7 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
703697 claims = jwt .decode (
704698 id_token ,
705699 key = jwk_set ,
706- claims_cls = claims_cls ,
700+ claims_cls = CodeIDToken ,
707701 claims_options = claim_options ,
708702 claims_params = claims_params ,
709703 )
@@ -713,15 +707,16 @@ async def _parse_id_token(self, token: Token, nonce: str) -> UserInfo:
713707 claims = jwt .decode (
714708 id_token ,
715709 key = jwk_set ,
716- claims_cls = claims_cls ,
710+ claims_cls = CodeIDToken ,
717711 claims_options = claim_options ,
718712 claims_params = claims_params ,
719713 )
720714
721715 logger .debug ("Decoded id_token JWT %r; validating" , claims )
722716
723717 claims .validate (leeway = 120 ) # allows 2 min of clock skew
724- return UserInfo (claims )
718+
719+ return claims
725720
726721 async def handle_redirect_request (
727722 self ,
@@ -837,22 +832,37 @@ async def handle_oidc_callback(
837832
838833 logger .debug ("Successfully obtained OAuth2 token data: %r" , token )
839834
840- # Now that we have a token, get the userinfo, either by decoding the
841- # `id_token` or by fetching the `userinfo_endpoint`.
835+ # If there is an id_token, it should be validated, regardless of the
836+ # userinfo endpoint is used or not.
837+ if token .get ("id_token" ) is not None :
838+ try :
839+ id_token = await self ._parse_id_token (token , nonce = session_data .nonce )
840+ sid = id_token .get ("sid" )
841+ except Exception as e :
842+ logger .exception ("Invalid id_token" )
843+ self ._sso_handler .render_error (request , "invalid_token" , str (e ))
844+ return
845+ else :
846+ id_token = None
847+ sid = None
848+
849+ # Now that we have a token, get the userinfo either from the `id_token`
850+ # claims or by fetching the `userinfo_endpoint`.
842851 if self ._uses_userinfo :
843852 try :
844853 userinfo = await self ._fetch_userinfo (token )
845854 except Exception as e :
846855 logger .exception ("Could not fetch userinfo" )
847856 self ._sso_handler .render_error (request , "fetch_error" , str (e ))
848857 return
858+ elif id_token is not None :
859+ userinfo = UserInfo (id_token )
849860 else :
850- try :
851- userinfo = await self ._parse_id_token (token , nonce = session_data .nonce )
852- except Exception as e :
853- logger .exception ("Invalid id_token" )
854- self ._sso_handler .render_error (request , "invalid_token" , str (e ))
855- return
861+ logger .error ("Missing id_token in token response" )
862+ self ._sso_handler .render_error (
863+ request , "invalid_token" , "Missing id_token in token response"
864+ )
865+ return
856866
857867 # first check if we're doing a UIA
858868 if session_data .ui_auth_session_id :
@@ -884,7 +894,7 @@ async def handle_oidc_callback(
884894 # Call the mapper to register/login the user
885895 try :
886896 await self ._complete_oidc_login (
887- userinfo , token , request , session_data .client_redirect_url
897+ userinfo , token , request , session_data .client_redirect_url , sid
888898 )
889899 except MappingException as e :
890900 logger .exception ("Could not map user" )
@@ -896,6 +906,7 @@ async def _complete_oidc_login(
896906 token : Token ,
897907 request : SynapseRequest ,
898908 client_redirect_url : str ,
909+ sid : Optional [str ],
899910 ) -> None :
900911 """Given a UserInfo response, complete the login flow
901912
@@ -1008,6 +1019,7 @@ async def grandfather_existing_users() -> Optional[str]:
10081019 oidc_response_to_user_attributes ,
10091020 grandfather_existing_users ,
10101021 extra_attributes ,
1022+ auth_provider_session_id = sid ,
10111023 )
10121024
10131025 def _remote_id_from_userinfo (self , userinfo : UserInfo ) -> str :
0 commit comments