Skip to content

Commit ba76d90

Browse files
author
Roland Hedberg
committed
Added code that allows encryption of an assertion in a response.
1 parent f67d61c commit ba76d90

File tree

1 file changed

+155
-16
lines changed

1 file changed

+155
-16
lines changed

src/saml2/sigver.py

Lines changed: 155 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from Crypto.Util.asn1 import DerSequence
3434
from Crypto.PublicKey import RSA
3535
from saml2.cert import OpenSSLWrapper
36+
from saml2.saml import EncryptedAssertion
3637
from saml2.samlp import Response
3738

3839
import xmldsig as ds
@@ -70,7 +71,8 @@
7071
RSA_SHA1 = "http://www.w3.org/2000/09/xmldsig#rsa-sha1"
7172
RSA_1_5 = "http://www.w3.org/2001/04/xmlenc#rsa-1_5"
7273
TRIPLE_DES_CBC = "http://www.w3.org/2001/04/xmlenc#tripledes-cbc"
73-
74+
XMLTAG = "<?xml version='1.0'?>"
75+
PREFIX = "<?xml version='1.0' encoding='UTF-8'?>"
7476

7577

7678
class SigverError(SAMLError):
@@ -97,6 +99,10 @@ class DecryptError(XmlsecError):
9799
pass
98100

99101

102+
class EncryptError(XmlsecError):
103+
pass
104+
105+
100106
class BadSignature(SigverError):
101107
"""The signature is invalid."""
102108
pass
@@ -106,6 +112,25 @@ class CertificateError(SigverError):
106112
pass
107113

108114

115+
def rm_xmltag(statement):
116+
try:
117+
_t = statement.startswith(XMLTAG)
118+
except TypeError:
119+
statement = statement.decode("utf8")
120+
_t = statement.startswith(XMLTAG)
121+
122+
if _t:
123+
statement = statement[len(XMLTAG):]
124+
if statement[0] == '\n':
125+
statement = statement[1:]
126+
elif statement.startswith(PREFIX):
127+
statement = statement[len(PREFIX):]
128+
if statement[0] == '\n':
129+
statement = statement[1:]
130+
131+
return statement
132+
133+
109134
def signed(item):
110135
if SIG in item.c_children.keys() and item.signature:
111136
return True
@@ -228,7 +253,7 @@ def _instance(klass, ava, seccont, base64encode=False, elements_to_sign=None):
228253
instance = klass()
229254

230255
for prop in instance.c_attributes.values():
231-
#print "# %s" % (prop)
256+
#print "# %s" % (prop)
232257
if prop in ava:
233258
if isinstance(ava[prop], bool):
234259
setattr(instance, prop, "%s" % ava[prop])
@@ -290,6 +315,7 @@ def signed_instance_factory(instance, seccont, elements_to_sign=None):
290315
else:
291316
return instance
292317

318+
293319
# --------------------------------------------------------------------------
294320

295321

@@ -305,7 +331,7 @@ def create_id():
305331
return ret
306332

307333

308-
def make_temp(string, suffix="", decode=True):
334+
def make_temp(string, suffix="", decode=True, delete=True):
309335
""" xmlsec needs files in some cases where only strings exist, hence the
310336
need for this function. It creates a temporary file with the
311337
string as only content.
@@ -319,7 +345,7 @@ def make_temp(string, suffix="", decode=True):
319345
close the file) and filename (which is for instance needed by the
320346
xmlsec function).
321347
"""
322-
ntf = NamedTemporaryFile(suffix=suffix)
348+
ntf = NamedTemporaryFile(suffix=suffix, delete=delete)
323349
if decode:
324350
ntf.write(base64.b64decode(string))
325351
else:
@@ -428,6 +454,7 @@ def cert_from_instance(instance):
428454
ignore_age=True)
429455
return []
430456

457+
431458
# =============================================================================
432459

433460

@@ -469,12 +496,14 @@ def key_from_key_value_dict(key_info):
469496
res.append(key)
470497
return res
471498

499+
472500
# =============================================================================
473501

474502

475503
#def rsa_load(filename):
476504
# """Read a PEM-encoded RSA key pair from a file."""
477-
# return M2Crypto.RSA.load_key(filename, M2Crypto.util.no_passphrase_callback)
505+
# return M2Crypto.RSA.load_key(filename, M2Crypto.util
506+
# .no_passphrase_callback)
478507
#
479508
#
480509
#def rsa_loads(key):
@@ -594,10 +623,12 @@ def verify_redirect_signature(info, cert):
594623
_order = RESP_ORDER
595624
else:
596625
raise Unsupported(
597-
"Verifying signature on something that should not be signed")
626+
"Verifying signature on something that should not be "
627+
"signed")
598628
args = info.copy()
599629
del args["Signature"] # everything but the signature
600-
string = "&".join([urllib.urlencode({k: args[k][0]}) for k in _order])
630+
string = "&".join(
631+
[urllib.urlencode({k: args[k][0]}) for k in _order])
601632
_key = extract_rsa_key_from_x509_cert(pem_format(cert))
602633
_sign = base64.b64decode(info["Signature"][0])
603634
try:
@@ -660,6 +691,9 @@ def version(self):
660691
def encrypt(self, text, recv_key, template, key_type):
661692
raise NotImplementedError()
662693

694+
def encrypt_assertion(self, statement, recv_key, key_type):
695+
raise NotImplementedError()
696+
663697
def decrypt(self, enctext, key_file):
664698
raise NotImplementedError()
665699

@@ -672,6 +706,10 @@ def validate_signature(self, enctext, cert_file, cert_type, node_name,
672706
raise NotImplementedError()
673707

674708

709+
ASSERT_XPATH = ''.join(["/*[local-name()=\"%s\"]" % v for v in [
710+
"Response", "EncryptedAssertion", "Assertion"]])
711+
712+
675713
class CryptoBackendXmlSec1(CryptoBackend):
676714
"""
677715
CryptoBackend implementation using external binary xmlsec1 to sign
@@ -705,6 +743,38 @@ def encrypt(self, text, recv_key, template, key_type):
705743
validate_output=False)
706744
return output
707745

746+
def encrypt_assertion(self, statement, enc_key, template,
747+
key_type="des-192"):
748+
"""
749+
--pubkey-cert-pem ../../example/idp2/pki/mycert.pem \
750+
--session-key des-192 --xml-data pre_saml2_assertion.xml \
751+
--node-xpath '/*[local-name()="Response"]/*[local-name(
752+
)="EncryptedAssertion"]/*[local-name()="Assertion"]' \
753+
enc-element-3des-kt-rsa1_5.tmpl > enc_3des_rsa_assertion.xml
754+
755+
:param statement:
756+
:param cert_file:
757+
:param cert_type:
758+
:return:
759+
"""
760+
statement = pre_encrypt_assertion(statement)
761+
_, fil = make_temp("%s" % statement, decode=False, delete=False)
762+
_, tmpl = make_temp("%s" % template, decode=False)
763+
764+
com_list = [self.xmlsec, "encrypt", "--pubkey-cert-pem", enc_key,
765+
"--session-key", key_type, "--xml-data", fil,
766+
"--node-xpath", ASSERT_XPATH]
767+
768+
(_stdout, _stderr, output) = self._run_xmlsec(com_list, [tmpl],
769+
exception=EncryptError,
770+
validate_output=False)
771+
772+
os.unlink(fil)
773+
if not output:
774+
raise EncryptError(_stderr)
775+
776+
return output
777+
708778
def decrypt(self, enctext, key_file):
709779
logger.debug("Decrypt input len: %d" % len(enctext))
710780
_, fil = make_temp("%s" % enctext, decode=False)
@@ -1005,7 +1075,7 @@ def __init__(self, security_context, cert_file=None, cert_type="pem",
10051075
self._cert_info = None
10061076
self._generate_cert_func_active = False
10071077
if generate_cert_info is not None and len(self._cert_str) > 0 and \
1008-
len(self._key_str) > 0 and tmp_key_file is not \
1078+
len(self._key_str) > 0 and tmp_key_file is not \
10091079
None and tmp_cert_file is not None:
10101080
self._generate_cert = True
10111081
self._cert_info = generate_cert_info
@@ -1037,9 +1107,12 @@ def update_cert(self, active=False, client_crt=None):
10371107
elif self._cert_handler_extra_class is not None and \
10381108
self._cert_handler_extra_class.use_generate_cert_func():
10391109
(self._tmp_cert_str, self._tmp_key_str) = \
1040-
self._cert_handler_extra_class.generate_cert(self._cert_info, self._cert_str, self._key_str)
1110+
self._cert_handler_extra_class.generate_cert(
1111+
self._cert_info, self._cert_str, self._key_str)
10411112
else:
1042-
self._tmp_cert_str, self._tmp_key_str = self._osw.create_certificate(self._cert_info, request=True)
1113+
self._tmp_cert_str, self._tmp_key_str = self._osw\
1114+
.create_certificate(
1115+
self._cert_info, request=True)
10431116
self._tmp_cert_str = self._osw.create_cert_signed_certificate(
10441117
self._cert_str, self._key_str, self._tmp_cert_str)
10451118
valid, mess = self._osw.verify(self._cert_str,
@@ -1123,6 +1196,19 @@ def encrypt(self, text, recv_key="", template="", key_type=""):
11231196

11241197
return self.crypto.encrypt(text, recv_key, template, key_type)
11251198

1199+
def encrypt_assertion(self, statement, cert_file, cert_type="pem"):
1200+
"""
1201+
--pubkey-cert-pem ../../example/idp2/pki/mycert.pem \
1202+
--session-key des-192 --xml-data pre_saml2_assertion.xml \
1203+
--node-xpath '/*[local-name()="Response"]/*[local-name(
1204+
)="EncryptedAssertion"]/*[local-name()="Assertion"]' \
1205+
enc-element-3des-kt-rsa1_5.tmpl > enc_3des_rsa_assertion.xml
1206+
:param statement:
1207+
:param cert_file:
1208+
:param cert_type:
1209+
:return:
1210+
"""
1211+
11261212
def decrypt(self, enctext):
11271213
""" Decrypting an encrypted text by the use of a private key.
11281214
@@ -1237,7 +1323,6 @@ def _check_signature(self, decoded_xml, item, node_name=NODE_NAME,
12371323
if not self.cert_handler.verify_cert(last_pem_file):
12381324
raise CertificateError("Invalid certificate!")
12391325

1240-
12411326
return item
12421327

12431328
def check_signature(self, item, node_name=NODE_NAME, origdoc=None,
@@ -1390,9 +1475,10 @@ def correctly_signed_response(self, decoded_xml, must=False, origdoc=None):
13901475
origdoc)
13911476

13921477
if isinstance(response, Response) and (response.assertion or
1393-
response.encrypted_assertion):
1478+
response.encrypted_assertion):
13941479
# Try to find the signing cert in the assertion
1395-
for assertion in (response.assertion or response.encrypted_assertion):
1480+
for assertion in (
1481+
response.assertion or response.encrypted_assertion):
13961482
if response.encrypted_assertion:
13971483
decoded_xml = self.decrypt(
13981484
assertion.encrypted_data.to_string())
@@ -1551,28 +1637,81 @@ def pre_signature_part(ident, public_key=None, identifier=None):
15511637
return signature
15521638

15531639

1554-
def pre_encryption_part(msg_enc=TRIPLE_DES_CBC, key_enc=RSA_1_5):
1640+
# <?xml version="1.0" encoding="UTF-8"?>
1641+
# <EncryptedData Id="ED" Type="http://www.w3.org/2001/04/xmlenc#Element"
1642+
# xmlns="http://www.w3.org/2001/04/xmlenc#">
1643+
# <EncryptionMethod Algorithm="http://www.w3
1644+
# .org/2001/04/xmlenc#tripledes-cbc"/>
1645+
# <ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
1646+
# <EncryptedKey Id="EK" xmlns="http://www.w3.org/2001/04/xmlenc#">
1647+
# <EncryptionMethod Algorithm="http://www.w3
1648+
# .org/2001/04/xmlenc#rsa-1_5"/>
1649+
# <ds:KeyInfo xmlns:ds="http://www.w3.org/2000/09/xmldsig#">
1650+
# <ds:KeyName>my-rsa-key</ds:KeyName>
1651+
# </ds:KeyInfo>
1652+
# <CipherData>
1653+
# <CipherValue>
1654+
# </CipherValue>
1655+
# </CipherData>
1656+
# <ReferenceList>
1657+
# <DataReference URI="#ED"/>
1658+
# </ReferenceList>
1659+
# </EncryptedKey>
1660+
# </ds:KeyInfo>
1661+
# <CipherData>
1662+
# <CipherValue>
1663+
# </CipherValue>
1664+
# </CipherData>
1665+
# </EncryptedData>
1666+
1667+
def pre_encryption_part(msg_enc=TRIPLE_DES_CBC, key_enc=RSA_1_5,
1668+
key_name="my-rsa-key"):
15551669
"""
15561670
15571671
:param msg_enc:
15581672
:param key_enc:
1673+
:param key_name:
15591674
:return:
15601675
"""
15611676
msg_encryption_method = EncryptionMethod(algorithm=msg_enc)
15621677
key_encryption_method = EncryptionMethod(algorithm=key_enc)
1563-
encrypted_key = EncryptedKey(encryption_method=key_encryption_method,
1678+
encrypted_key = EncryptedKey(id="EK",
1679+
encryption_method=key_encryption_method,
15641680
key_info=ds.KeyInfo(
1565-
key_name=ds.KeyName(text="")),
1681+
key_name=ds.KeyName(text=key_name)),
15661682
cipher_data=CipherData(
15671683
cipher_value=CipherValue(text="")))
15681684
key_info = ds.KeyInfo(encrypted_key=encrypted_key)
15691685
encrypted_data = EncryptedData(
1686+
id="ED",
1687+
type="http://www.w3.org/2001/04/xmlenc#Element",
15701688
encryption_method=msg_encryption_method,
15711689
key_info=key_info,
15721690
cipher_data=CipherData(cipher_value=CipherValue(text="")))
15731691
return encrypted_data
15741692

15751693

1694+
def pre_encrypt_assertion(response):
1695+
"""
1696+
Move the assertion to within a encrypted_assertion
1697+
:param response: The response with one assertion
1698+
:return: The response but now with the assertion within an
1699+
encrypted_assertion.
1700+
"""
1701+
assertion = response.assertion
1702+
response.assertion = None
1703+
response.encrypted_assertion = EncryptedAssertion()
1704+
response.encrypted_assertion.add_extension_element(assertion)
1705+
# txt = "%s" % response
1706+
# _ass = "%s" % assertion
1707+
# _ass = rm_xmltag(_ass)
1708+
# txt.replace(
1709+
# "<ns1:EncryptedAssertion/>",
1710+
# "<ns1:EncryptedAssertion>%s</ns1:EncryptedAssertion>" % _ass)
1711+
1712+
return response
1713+
1714+
15761715
def response_factory(sign=False, encrypt=False, **kwargs):
15771716
response = samlp.Response(id=sid(), version=VERSION,
15781717
issue_instant=instant())

0 commit comments

Comments
 (0)