diff --git a/src/onelogin/saml2/auth.py b/src/onelogin/saml2/auth.py
index a67e2071..7f0880cd 100644
--- a/src/onelogin/saml2/auth.py
+++ b/src/onelogin/saml2/auth.py
@@ -11,8 +11,10 @@
"""
+import logging
import xmlsec
+
from onelogin.saml2 import compat
from onelogin.saml2.authn_request import OneLogin_Saml2_Authn_Request
from onelogin.saml2.constants import OneLogin_Saml2_Constants
@@ -24,6 +26,9 @@
from onelogin.saml2.xmlparser import tostring
+logger = logging.getLogger(__name__)
+
+
class OneLogin_Saml2_Auth(object):
"""
@@ -378,6 +383,41 @@ def get_last_authn_contexts(self):
"""
return self.__last_authn_contexts
+ def _create_authn_request(
+ self, force_authn=False, is_passive=False, set_nameid_policy=True, name_id_value_req=None
+ ):
+ authn_request = self.authn_request_class(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req)
+
+ self.__last_request = authn_request.get_xml()
+ self.__last_request_id = authn_request.get_id()
+ return authn_request
+
+ def login_post(self, return_to=None, **authn_kwargs):
+ authn_request = self._create_authn_request(**authn_kwargs)
+
+ url = self.get_sso_url()
+ data = authn_request.get_request(deflate=False, base64_encode=False)
+ saml_request = OneLogin_Saml2_Utils.b64encode(
+ OneLogin_Saml2_Utils.add_sign(
+ data,
+ self.__settings.get_sp_key(), self.__settings.get_sp_cert(),
+ sign_algorithm=OneLogin_Saml2_Constants.RSA_SHA256,
+ digest_algorithm=OneLogin_Saml2_Constants.SHA256,),
+
+ )
+ logger.debug(
+ "Returning form-data to the user for a AuthNRequest to %s with SAMLRequest %s",
+ url, OneLogin_Saml2_Utils.b64decode(saml_request).decode('utf-8')
+ )
+ parameters = {'SAMLRequest': saml_request}
+
+ if return_to is not None:
+ parameters['RelayState'] = return_to
+ else:
+ parameters['RelayState'] = OneLogin_Saml2_Utils.get_self_url_no_query(self.__request_data)
+
+ return url, parameters
+
def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_policy=True, name_id_value_req=None):
"""
Initiates the SSO process.
@@ -400,9 +440,10 @@ def login(self, return_to=None, force_authn=False, is_passive=False, set_nameid_
:returns: Redirection URL
:rtype: string
"""
- authn_request = self.authn_request_class(self.__settings, force_authn, is_passive, set_nameid_policy, name_id_value_req)
- self.__last_request = authn_request.get_xml()
- self.__last_request_id = authn_request.get_id()
+ authn_request = self._create_authn_request(
+ force_authn=force_authn, is_passive=is_passive,
+ set_nameid_policy=set_nameid_policy, name_id_value_req=name_id_value_req
+ )
saml_request = authn_request.get_request()
parameters = {'SAMLRequest': saml_request}
diff --git a/src/onelogin/saml2/authn_request.py b/src/onelogin/saml2/authn_request.py
index 48ad9d2a..38c0e686 100644
--- a/src/onelogin/saml2/authn_request.py
+++ b/src/onelogin/saml2/authn_request.py
@@ -134,7 +134,7 @@ def _generate_request_id(self):
"""
return OneLogin_Saml2_Utils.generate_unique_id()
- def get_request(self, deflate=True):
+ def get_request(self, deflate=True, base64_encode=True):
"""
Returns unsigned AuthnRequest.
:param deflate: It makes the deflate process optional
@@ -143,9 +143,12 @@ def get_request(self, deflate=True):
:rtype: str object
"""
if deflate:
+ assert base64_encode is True, "Deflate without base64 encoding is not supported"
request = OneLogin_Saml2_Utils.deflate_and_base64_encode(self.__authn_request)
- else:
+ elif base64_encode:
request = OneLogin_Saml2_Utils.b64encode(self.__authn_request)
+ else:
+ request = self.__authn_request
return request
def get_id(self):
diff --git a/tests/src/OneLogin/saml2_tests/auth_test.py b/tests/src/OneLogin/saml2_tests/auth_test.py
index ea8cf1d3..e456de2e 100644
--- a/tests/src/OneLogin/saml2_tests/auth_test.py
+++ b/tests/src/OneLogin/saml2_tests/auth_test.py
@@ -605,6 +605,37 @@ def testLoginWithRelayState(self):
self.assertIn('RelayState', parsed_query)
self.assertIn(relay_state, parsed_query['RelayState'])
+ def testLoginPost(self):
+ settings_info = self.loadSettingsJSON()
+ request_data = self.get_request()
+ auth = OneLogin_Saml2_Auth(self.get_request(), old_settings=settings_info)
+
+ url, parameters = auth.login_post()
+ self.assertEqual(url, 'http://idp.example.com/SSOService.php')
+ # self.assertEqual(parameters['RelayState'], relay_state)
+ saml_request = b64decode(parameters['SAMLRequest'])
+ self.assertTrue(saml_request.startswith(b'', saml_request)
+ self.assertIn(b'', saml_request)
+
+ def testLoginPostWithRelayState(self):
+ settings_info = self.loadSettingsJSON()
+ auth = OneLogin_Saml2_Auth(self.get_request(), old_settings=settings_info)
+ relay_state = 'http://sp.example.com'
+
+ url, parameters = auth.login_post(relay_state)
+ self.assertEqual(url, 'http://idp.example.com/SSOService.php')
+ self.assertEqual(parameters['RelayState'], relay_state)
+ saml_request = b64decode(parameters['SAMLRequest'])
+ self.assertTrue(saml_request.startswith(b'', saml_request)
+ self.assertIn(b'', saml_request)
+
def testLoginSigned(self):
"""
Tests the login method of the OneLogin_Saml2_Auth class