1616import logging
1717from typing import (
1818 TYPE_CHECKING ,
19+ Any ,
1920 Awaitable ,
2021 Callable ,
2122 Dict ,
2223 Iterable ,
24+ List ,
2325 Mapping ,
2426 Optional ,
2527 Set ,
3436
3537from synapse .api .constants import LoginType
3638from synapse .api .errors import Codes , NotFoundError , RedirectException , SynapseError
39+ from synapse .config .sso import SsoAttributeRequirement
3740from synapse .handlers .ui_auth import UIAuthSessionDataConstants
3841from synapse .http import get_request_user_agent
3942from synapse .http .server import respond_with_html , respond_with_redirect
@@ -893,6 +896,41 @@ def _expire_old_sessions(self):
893896 logger .info ("Expiring mapping session %s" , session_id )
894897 del self ._username_mapping_sessions [session_id ]
895898
899+ def check_required_attributes (
900+ self ,
901+ request : SynapseRequest ,
902+ attributes : Mapping [str , List [Any ]],
903+ attribute_requirements : Iterable [SsoAttributeRequirement ],
904+ ) -> bool :
905+ """
906+ Confirm that the required attributes were present in the SSO response.
907+
908+ If all requirements are met, this will return True.
909+
910+ If any requirement is not met, then the request will be finalized by
911+ showing an error page to the user and False will be returned.
912+
913+ Args:
914+ request: The request to (potentially) respond to.
915+ attributes: The attributes from the SSO IdP.
916+ attribute_requirements: The requirements that attributes must meet.
917+
918+ Returns:
919+ True if all requirements are met, False if any attribute fails to
920+ meet the requirement.
921+
922+ """
923+ # Ensure that the attributes of the logged in user meet the required
924+ # attributes.
925+ for requirement in attribute_requirements :
926+ if not _check_attribute_requirement (attributes , requirement ):
927+ self .render_error (
928+ request , "unauthorised" , "You are not authorised to log in here."
929+ )
930+ return False
931+
932+ return True
933+
896934
897935def get_username_mapping_session_cookie_from_request (request : IRequest ) -> str :
898936 """Extract the session ID from the cookie
@@ -903,3 +941,36 @@ def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
903941 if not session_id :
904942 raise SynapseError (code = 400 , msg = "missing session_id" )
905943 return session_id .decode ("ascii" , errors = "replace" )
944+
945+
946+ def _check_attribute_requirement (
947+ attributes : Mapping [str , List [Any ]], req : SsoAttributeRequirement
948+ ) -> bool :
949+ """Check if SSO attributes meet the proper requirements.
950+
951+ Args:
952+ attributes: A mapping of attributes to an iterable of one or more values.
953+ requirement: The configured requirement to check.
954+
955+ Returns:
956+ True if the required attribute was found and had a proper value.
957+ """
958+ if req .attribute not in attributes :
959+ logger .info ("SSO attribute missing: %s" , req .attribute )
960+ return False
961+
962+ # If the requirement is None, the attribute existing is enough.
963+ if req .value is None :
964+ return True
965+
966+ values = attributes [req .attribute ]
967+ if req .value in values :
968+ return True
969+
970+ logger .info (
971+ "SSO attribute %s did not match required value '%s' (was '%s')" ,
972+ req .attribute ,
973+ req .value ,
974+ values ,
975+ )
976+ return False
0 commit comments