3737from synapse .http .server import respond_with_html
3838from synapse .http .site import SynapseRequest
3939from synapse .logging .context import make_deferred_yieldable
40- from synapse .types import UserID , map_username_to_mxid_localpart
40+ from synapse .types import JsonDict , UserID , map_username_to_mxid_localpart
4141from synapse .util import json_decoder
4242
4343if TYPE_CHECKING :
@@ -707,14 +707,23 @@ async def handle_oidc_callback(self, request: SynapseRequest) -> None:
707707 self ._render_error (request , "mapping_error" , str (e ))
708708 return
709709
710+ # Mapping providers might not have get_extra_attributes: only call this
711+ # method if it exists.
712+ extra_attributes = None
713+ get_extra_attributes = getattr (
714+ self ._user_mapping_provider , "get_extra_attributes" , None
715+ )
716+ if get_extra_attributes :
717+ extra_attributes = await get_extra_attributes (userinfo , token )
718+
710719 # and finally complete the login
711720 if ui_auth_session_id :
712721 await self ._auth_handler .complete_sso_ui_auth (
713722 user_id , ui_auth_session_id , request
714723 )
715724 else :
716725 await self ._auth_handler .complete_sso_login (
717- user_id , request , client_redirect_url
726+ user_id , request , client_redirect_url , extra_attributes
718727 )
719728
720729 def _generate_oidc_session_token (
@@ -984,7 +993,7 @@ def get_remote_user_id(self, userinfo: UserInfo) -> str:
984993 async def map_user_attributes (
985994 self , userinfo : UserInfo , token : Token
986995 ) -> UserAttribute :
987- """Map a `` UserInfo`` objects into user attributes.
996+ """Map a `UserInfo` object into user attributes.
988997
989998 Args:
990999 userinfo: An object representing the user given by the OIDC provider
@@ -995,6 +1004,18 @@ async def map_user_attributes(
9951004 """
9961005 raise NotImplementedError ()
9971006
1007+ async def get_extra_attributes (self , userinfo : UserInfo , token : Token ) -> JsonDict :
1008+ """Map a `UserInfo` object into additional attributes passed to the client during login.
1009+
1010+ Args:
1011+ userinfo: An object representing the user given by the OIDC provider
1012+ token: A dict with the tokens returned by the provider
1013+
1014+ Returns:
1015+ A dict containing additional attributes. Must be JSON serializable.
1016+ """
1017+ return {}
1018+
9981019
9991020# Used to clear out "None" values in templates
10001021def jinja_finalize (thing ):
@@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig:
10091030 subject_claim = attr .ib () # type: str
10101031 localpart_template = attr .ib () # type: Template
10111032 display_name_template = attr .ib () # type: Optional[Template]
1033+ extra_attributes = attr .ib () # type: Dict[str, Template]
10121034
10131035
10141036class JinjaOidcMappingProvider (OidcMappingProvider [JinjaOidcMappingConfig ]):
@@ -1047,10 +1069,28 @@ def parse_config(config: dict) -> JinjaOidcMappingConfig:
10471069 % (e ,)
10481070 )
10491071
1072+ extra_attributes = {} # type Dict[str, Template]
1073+ if "extra_attributes" in config :
1074+ extra_attributes_config = config .get ("extra_attributes" ) or {}
1075+ if not isinstance (extra_attributes_config , dict ):
1076+ raise ConfigError (
1077+ "oidc_config.user_mapping_provider.config.extra_attributes must be a dict"
1078+ )
1079+
1080+ for key , value in extra_attributes_config .items ():
1081+ try :
1082+ extra_attributes [key ] = env .from_string (value )
1083+ except Exception as e :
1084+ raise ConfigError (
1085+ "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r"
1086+ % (key , e )
1087+ )
1088+
10501089 return JinjaOidcMappingConfig (
10511090 subject_claim = subject_claim ,
10521091 localpart_template = localpart_template ,
10531092 display_name_template = display_name_template ,
1093+ extra_attributes = extra_attributes ,
10541094 )
10551095
10561096 def get_remote_user_id (self , userinfo : UserInfo ) -> str :
@@ -1071,3 +1111,13 @@ async def map_user_attributes(
10711111 display_name = None
10721112
10731113 return UserAttribute (localpart = localpart , display_name = display_name )
1114+
1115+ async def get_extra_attributes (self , userinfo : UserInfo , token : Token ) -> JsonDict :
1116+ extras = {} # type: Dict[str, str]
1117+ for key , template in self ._config .extra_attributes .items ():
1118+ try :
1119+ extras [key ] = template .render (user = userinfo ).strip ()
1120+ except Exception as e :
1121+ # Log an error and skip this value (don't break login for this).
1122+ logger .error ("Failed to render OIDC extra attribute %s: %s" % (key , e ))
1123+ return extras
0 commit comments