From 8ce02da65d43ba4675d1d1b0f4c8e72536240636 Mon Sep 17 00:00:00 2001 From: Dan Watson Date: Thu, 21 Jun 2018 16:00:24 -0400 Subject: [PATCH] Support multiple assertion consumer services (#102) --- src/onelogin/saml2/auth.py | 7 ++- src/onelogin/saml2/authn_request.py | 21 ++++++- src/onelogin/saml2/metadata.py | 18 +++++- src/onelogin/saml2/settings.py | 28 +++++++-- src/onelogin/saml2/xml_templates.py | 9 ++- tests/settings/settings9.json | 58 +++++++++++++++++++ .../saml2_tests/authn_request_test.py | 11 ++++ 7 files changed, 137 insertions(+), 15 deletions(-) create mode 100644 tests/settings/settings9.json diff --git a/src/onelogin/saml2/auth.py b/src/onelogin/saml2/auth.py index df732701..ff0a69b9 100644 --- a/src/onelogin/saml2/auth.py +++ b/src/onelogin/saml2/auth.py @@ -318,7 +318,7 @@ def get_last_assertion_id(self): """ return self.__last_assertion_id - def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_policy=True): + def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_policy=True, acs_index=None): """ Initiates the SSO process. @@ -334,10 +334,13 @@ def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_ :param set_nameid_policy: Optional argument. When true the AuthNRequest will set a nameIdPolicy element. :type set_nameid_policy: bool + :param acs_index: Optional argument. The index of the assertionConsumerService to use, if multiple were specified. + :type acs_index: int + :returns: Redirection URL :rtype: string """ - authn_request = OneLogin_Saml2_Authn_Request(self.__settings, force_authn, is_passive, set_nameid_policy) + authn_request = OneLogin_Saml2_Authn_Request(self.__settings, force_authn, is_passive, set_nameid_policy, acs_index) self.__last_request = authn_request.get_xml() self.__last_request_id = authn_request.get_id() diff --git a/src/onelogin/saml2/authn_request.py b/src/onelogin/saml2/authn_request.py index 2a1eda88..104dd761 100644 --- a/src/onelogin/saml2/authn_request.py +++ b/src/onelogin/saml2/authn_request.py @@ -9,6 +9,7 @@ """ +from onelogin.saml2 import compat from onelogin.saml2.constants import OneLogin_Saml2_Constants from onelogin.saml2.utils import OneLogin_Saml2_Utils from onelogin.saml2.xml_templates import OneLogin_Saml2_Templates @@ -22,7 +23,7 @@ class OneLogin_Saml2_Authn_Request(object): """ - def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_policy=True): + def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_policy=True, acs_index=None): """ Constructs the AuthnRequest object. @@ -37,6 +38,9 @@ def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_pol :param set_nameid_policy: Optional argument. When true the AuthNRequest will set a nameIdPolicy element. :type set_nameid_policy: bool + + :param acs_index: Optional argument. The index of the assertionConsumerService to use, if multiple were specified. + :type acs_index: int """ self.__settings = settings @@ -102,6 +106,19 @@ def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_pol if 'attributeConsumingService' in sp_data and sp_data['attributeConsumingService']: attr_consuming_service_str = "\n AttributeConsumingServiceIndex=\"1\"" + assertion_url = '' + if isinstance(sp_data['assertionConsumerService'], dict): + assertion_url = sp_data['assertionConsumerService']['url'] + else: + for idx, acs in enumerate(sp_data['assertionConsumerService']): + if idx == 0: + # By default, use the first assertion consumer service if an index is not specified. + assertion_url = acs['url'] + index = compat.to_string(acs.get('index', idx)) + if index == compat.to_string(acs_index): + assertion_url = acs['url'] + break + request = OneLogin_Saml2_Templates.AUTHN_REQUEST % \ { 'id': uid, @@ -110,7 +127,7 @@ def __init__(self, settings, force_authn=False, is_passive=False, set_nameid_pol 'is_passive_str': is_passive_str, 'issue_instant': issue_instant, 'destination': destination, - 'assertion_url': sp_data['assertionConsumerService']['url'], + 'assertion_url': assertion_url, 'entity_id': sp_data['entityId'], 'nameid_policy_str': nameid_policy_str, 'requested_authn_context_str': requested_authn_context_str, diff --git a/src/onelogin/saml2/metadata.py b/src/onelogin/saml2/metadata.py index 9d58e5b7..70fee036 100644 --- a/src/onelogin/saml2/metadata.py +++ b/src/onelogin/saml2/metadata.py @@ -173,6 +173,21 @@ def builder(sp, authnsign=False, wsign=False, valid_until=None, cache_duration=N 'requested_attribute_str': '\n'.join(requested_attribute_data) } + str_assertion_consumers = '' + if isinstance(sp['assertionConsumerService'], dict): + str_assertion_consumers += OneLogin_Saml2_Templates.MD_ASSERTION_CONSUMER_SERVICE % { + 'binding': sp['assertionConsumerService']['binding'], + 'location': sp['assertionConsumerService']['url'], + 'index': sp['assertionConsumerService'].get('index', '1') + } + else: + for idx, acs in enumerate(sp['assertionConsumerService']): + str_assertion_consumers += OneLogin_Saml2_Templates.MD_ASSERTION_CONSUMER_SERVICE % { + 'binding': acs['binding'], + 'location': acs['url'], + 'index': acs.get('index', compat.to_string(idx)) + } + metadata = OneLogin_Saml2_Templates.MD_ENTITY_DESCRIPTOR % \ { 'valid': ('validUntil="%s"' % valid_until_str) if valid_until_str else '', @@ -181,8 +196,7 @@ def builder(sp, authnsign=False, wsign=False, valid_until=None, cache_duration=N 'authnsign': str_authnsign, 'wsign': str_wsign, 'name_id_format': sp['NameIDFormat'], - 'binding': sp['assertionConsumerService']['binding'], - 'location': sp['assertionConsumerService']['url'], + 'assertion_consumers': str_assertion_consumers, 'sls': sls, 'organization': str_organization, 'contacts': str_contacts, diff --git a/src/onelogin/saml2/settings.py b/src/onelogin/saml2/settings.py index 4dd6b220..ffd2eefc 100644 --- a/src/onelogin/saml2/settings.py +++ b/src/onelogin/saml2/settings.py @@ -258,8 +258,12 @@ def __add_default_values(self): """ Add default values if the settings info is not complete """ - self.__sp.setdefault('assertionConsumerService', {}) - self.__sp['assertionConsumerService'].setdefault('binding', OneLogin_Saml2_Constants.BINDING_HTTP_POST) + acs = self.__sp.setdefault('assertionConsumerService', {}) + if isinstance(acs, dict): + acs.setdefault('binding', OneLogin_Saml2_Constants.BINDING_HTTP_POST) + else: + for entry in acs: + entry.setdefault('binding', OneLogin_Saml2_Constants.BINDING_HTTP_POST) self.__sp.setdefault('attributeConsumingService', {}) @@ -415,10 +419,22 @@ def check_sp_settings(self, settings): if not sp.get('entityId'): errors.append('sp_entityId_not_found') - if not sp.get('assertionConsumerService', {}).get('url'): - errors.append('sp_acs_not_found') - elif not validate_url(sp['assertionConsumerService']['url']): - errors.append('sp_acs_url_invalid') + acs_list = sp.get('assertionConsumerService', {}) + acs_indexes = set() + if isinstance(acs_list, dict): + acs_list = [acs_list] + if not isinstance(acs_list, list): + errors.append('sp_acs_invalid_type') + else: + for idx, acs in enumerate(acs_list): + index = compat.to_string(acs.get('index', idx)) + if index in acs_indexes: + errors.append('sp_acs_duplicate_index') + acs_indexes.add(index) + if not acs.get('url'): + errors.append('sp_acs_not_found') + elif not validate_url(acs['url']): + errors.append('sp_acs_url_invalid') if sp.get('attributeConsumingService'): attributeConsumingService = sp['attributeConsumingService'] diff --git a/src/onelogin/saml2/xml_templates.py b/src/onelogin/saml2/xml_templates.py index 99025546..54330fab 100644 --- a/src/onelogin/saml2/xml_templates.py +++ b/src/onelogin/saml2/xml_templates.py @@ -80,6 +80,11 @@ class OneLogin_Saml2_Templates(object): %(attr_cs_desc)s%(requested_attribute_str)s \n""" + MD_ASSERTION_CONSUMER_SERVICE = """\ + \n""" + MD_ENTITY_DESCRIPTOR = """\ %(sls)s %(name_id_format)s - + %(assertion_consumers)s %(attribute_consuming_service)s %(organization)s %(contacts)s diff --git a/tests/settings/settings9.json b/tests/settings/settings9.json new file mode 100644 index 00000000..7ed1e47e --- /dev/null +++ b/tests/settings/settings9.json @@ -0,0 +1,58 @@ +{ + "strict": false, + "debug": false, + "custom_base_path": "../../../tests/data/customPath/", + "sp": { + "entityId": "http://stuff.com/endpoints/metadata.php", + "assertionConsumerService": [ + { + "url": "http://stuff.com/endpoints/endpoints/acs.php", + "index": "123" + }, + { + "url": "http://stuff.com/endpoints/endpoints/acs2.php", + "index": "456" + }, + { + "url": "http://stuff.com/endpoints/endpoints/acs3.php", + "index": "789" + } + ], + "singleLogoutService": { + "url": "http://stuff.com/endpoints/endpoints/sls.php" + }, + "NameIDFormat": "urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified" + }, + "idp": { + "entityId": "http://idp.example.com/", + "singleSignOnService": { + "url": "http://idp.example.com/SSOService.php" + }, + "singleLogoutService": { + "url": "http://idp.example.com/SingleLogoutService.php" + }, + "x509cert": "MIICgTCCAeoCCQCbOlrWDdX7FTANBgkqhkiG9w0BAQUFADCBhDELMAkGA1UEBhMCTk8xGDAWBgNVBAgTD0FuZHJlYXMgU29sYmVyZzEMMAoGA1UEBxMDRm9vMRAwDgYDVQQKEwdVTklORVRUMRgwFgYDVQQDEw9mZWlkZS5lcmxhbmcubm8xITAfBgkqhkiG9w0BCQEWEmFuZHJlYXNAdW5pbmV0dC5ubzAeFw0wNzA2MTUxMjAxMzVaFw0wNzA4MTQxMjAxMzVaMIGEMQswCQYDVQQGEwJOTzEYMBYGA1UECBMPQW5kcmVhcyBTb2xiZXJnMQwwCgYDVQQHEwNGb28xEDAOBgNVBAoTB1VOSU5FVFQxGDAWBgNVBAMTD2ZlaWRlLmVybGFuZy5ubzEhMB8GCSqGSIb3DQEJARYSYW5kcmVhc0B1bmluZXR0Lm5vMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDivbhR7P516x/S3BqKxupQe0LONoliupiBOesCO3SHbDrl3+q9IbfnfmE04rNuMcPsIxB161TdDpIesLCn7c8aPHISKOtPlAeTZSnb8QAu7aRjZq3+PbrP5uW3TcfCGPtKTytHOge/OlJbo078dVhXQ14d1EDwXJW1rRXuUt4C8QIDAQABMA0GCSqGSIb3DQEBBQUAA4GBACDVfp86HObqY+e8BUoWQ9+VMQx1ASDohBjwOsg2WykUqRXF+dLfcUH9dWR63CtZIKFDbStNomPnQz7nbK+onygwBspVEbnHuUihZq3ZUdmumQqCw4Uvs/1Uvq3orOo/WJVhTyvLgFVK2QarQ4/67OZfHd7R+POBXhophSMv1ZOo" + }, + "security": { + "authnRequestsSigned": false, + "wantAssertionsSigned": false, + "signMetadata": false + }, + "contactPerson": { + "technical": { + "givenName": "technical_name", + "emailAddress": "technical@example.com" + }, + "support": { + "givenName": "support_name", + "emailAddress": "support@example.com" + } + }, + "organization": { + "en-US": { + "name": "sp_test", + "displayname": "SP test", + "url": "http://sp.example.com" + } + } +} diff --git a/tests/src/OneLogin/saml2_tests/authn_request_test.py b/tests/src/OneLogin/saml2_tests/authn_request_test.py index c0dcbf2c..8a4d909f 100644 --- a/tests/src/OneLogin/saml2_tests/authn_request_test.py +++ b/tests/src/OneLogin/saml2_tests/authn_request_test.py @@ -339,3 +339,14 @@ def testAttributeConsumingService(self): inflated = compat.to_string(OneLogin_Saml2_Utils.decode_base64_and_inflate(authn_request_encoded)) self.assertRegex(inflated, 'AttributeConsumingServiceIndex="1"') + + def testMultipleAssertionConsumerServices(self): + settings_data = self.loadSettingsJSON('settings9.json') + settings = OneLogin_Saml2_Settings(settings_data) + self.assertEqual(len(settings.get_errors()), 0) + + authn_request = OneLogin_Saml2_Authn_Request(settings, acs_index=456) + authn_request_encoded = authn_request.get_request() + inflated = compat.to_string(OneLogin_Saml2_Utils.decode_base64_and_inflate(authn_request_encoded)) + + self.assertRegex(inflated, 'AssertionConsumerServiceURL="http://stuff.com/endpoints/endpoints/acs2.php">')