diff --git a/src/onelogin/saml2/auth.py b/src/onelogin/saml2/auth.py index a67e2071..7f0880cd 100644 --- a/src/onelogin/saml2/auth.py +++ b/src/onelogin/saml2/auth.py @@ -11,8 +11,10 @@ """ +import logging import xmlsec + from onelogin.saml2 import compat from onelogin.saml2.authn_request import OneLogin_Saml2_Authn_Request from onelogin.saml2.constants import OneLogin_Saml2_Constants @@ -24,6 +26,9 @@ from onelogin.saml2.xmlparser import tostring +logger = logging.getLogger(__name__) + + class OneLogin_Saml2_Auth(object): """ @@ -378,6 +383,41 @@ def get_last_authn_contexts(self): """ return self.__last_authn_contexts + def _create_authn_request( + self, force_authn=False, is_passive=False, set_nameid_policy=True, name_id_value_req=None + ): + authn_request = self.authn_request_class(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req) + + self.__last_request = authn_request.get_xml() + self.__last_request_id = authn_request.get_id() + return authn_request + + def login_post(self, return_to=None, **authn_kwargs): + authn_request = self._create_authn_request(**authn_kwargs) + + url = self.get_sso_url() + data = authn_request.get_request(deflate=False, base64_encode=False) + saml_request = OneLogin_Saml2_Utils.b64encode( + OneLogin_Saml2_Utils.add_sign( + data, + self.__settings.get_sp_key(), self.__settings.get_sp_cert(), + sign_algorithm=OneLogin_Saml2_Constants.RSA_SHA256, + digest_algorithm=OneLogin_Saml2_Constants.SHA256,), + + ) + logger.debug( + "Returning form-data to the user for a AuthNRequest to %s with SAMLRequest %s", + url, OneLogin_Saml2_Utils.b64decode(saml_request).decode('utf-8') + ) + parameters = {'SAMLRequest': saml_request} + + if return_to is not None: + parameters['RelayState'] = return_to + else: + parameters['RelayState'] = OneLogin_Saml2_Utils.get_self_url_no_query(self.__request_data) + + return url, parameters + def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_policy=True, name_id_value_req=None): """ Initiates the SSO process. @@ -400,9 +440,10 @@ def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_ :returns: Redirection URL :rtype: string """ - authn_request = self.authn_request_class(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req) - self.__last_request = authn_request.get_xml() - self.__last_request_id = authn_request.get_id() + authn_request = self._create_authn_request( + force_authn=force_authn, is_passive=is_passive, + set_nameid_policy=set_nameid_policy, name_id_value_req=name_id_value_req + ) saml_request = authn_request.get_request() parameters = {'SAMLRequest': saml_request} diff --git a/src/onelogin/saml2/authn_request.py b/src/onelogin/saml2/authn_request.py index 48ad9d2a..38c0e686 100644 --- a/src/onelogin/saml2/authn_request.py +++ b/src/onelogin/saml2/authn_request.py @@ -134,7 +134,7 @@ def _generate_request_id(self): """ return OneLogin_Saml2_Utils.generate_unique_id() - def get_request(self, deflate=True): + def get_request(self, deflate=True, base64_encode=True): """ Returns unsigned AuthnRequest. :param deflate: It makes the deflate process optional @@ -143,9 +143,12 @@ def get_request(self, deflate=True): :rtype: str object """ if deflate: + assert base64_encode is True, "Deflate without base64 encoding is not supported" request = OneLogin_Saml2_Utils.deflate_and_base64_encode(self.__authn_request) - else: + elif base64_encode: request = OneLogin_Saml2_Utils.b64encode(self.__authn_request) + else: + request = self.__authn_request return request def get_id(self): diff --git a/tests/src/OneLogin/saml2_tests/auth_test.py b/tests/src/OneLogin/saml2_tests/auth_test.py index ea8cf1d3..e456de2e 100644 --- a/tests/src/OneLogin/saml2_tests/auth_test.py +++ b/tests/src/OneLogin/saml2_tests/auth_test.py @@ -605,6 +605,37 @@ def testLoginWithRelayState(self): self.assertIn('RelayState', parsed_query) self.assertIn(relay_state, parsed_query['RelayState']) + def testLoginPost(self): + settings_info = self.loadSettingsJSON() + request_data = self.get_request() + auth = OneLogin_Saml2_Auth(self.get_request(), old_settings=settings_info) + + url, parameters = auth.login_post() + self.assertEqual(url, 'http://idp.example.com/SSOService.php') + # self.assertEqual(parameters['RelayState'], relay_state) + saml_request = b64decode(parameters['SAMLRequest']) + self.assertTrue(saml_request.startswith(b'', saml_request) + self.assertIn(b'', saml_request) + + def testLoginPostWithRelayState(self): + settings_info = self.loadSettingsJSON() + auth = OneLogin_Saml2_Auth(self.get_request(), old_settings=settings_info) + relay_state = 'http://sp.example.com' + + url, parameters = auth.login_post(relay_state) + self.assertEqual(url, 'http://idp.example.com/SSOService.php') + self.assertEqual(parameters['RelayState'], relay_state) + saml_request = b64decode(parameters['SAMLRequest']) + self.assertTrue(saml_request.startswith(b'', saml_request) + self.assertIn(b'', saml_request) + def testLoginSigned(self): """ Tests the login method of the OneLogin_Saml2_Auth class