Skip to content

Commit e2b0461

Browse files
author
Hans Hörberg
committed
Pysaml can now decrypt multiple encrypted assertions with multiple advice elements with multiple encrypted assertions.
1 parent e70835b commit e2b0461

File tree

6 files changed

+575
-69
lines changed

6 files changed

+575
-69
lines changed

src/saml2/__init__.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -591,13 +591,14 @@ def get_prefix_map(self, elements):
591591

592592
def get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion(self, assertion_tag, advice_tag):
593593
for tmp_encrypted_assertion in self.assertion.advice.encrypted_assertion:
594-
prefix_map = self.get_prefix_map([tmp_encrypted_assertion._to_element_tree().
595-
find(assertion_tag)])
596-
597-
tree = self._to_element_tree()
598-
599-
self.set_prefixes(tree.find(assertion_tag).find(advice_tag).find(tmp_encrypted_assertion._to_element_tree()
600-
.tag).find(assertion_tag), prefix_map)
594+
if tmp_encrypted_assertion.encrypted_data is None:
595+
prefix_map = self.get_prefix_map([tmp_encrypted_assertion._to_element_tree().find(assertion_tag)])
596+
tree = self._to_element_tree()
597+
encs = tree.find(assertion_tag).find(advice_tag).findall(tmp_encrypted_assertion._to_element_tree().tag)
598+
for enc in encs:
599+
assertion = enc.find(assertion_tag)
600+
if assertion is not None:
601+
self.set_prefixes(assertion, prefix_map)
601602

602603
return ElementTree.tostring(tree, encoding="UTF-8")
603604

src/saml2/client_base.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -542,7 +542,7 @@ def create_name_id_mapping_request(self, name_id_policy,
542542
# ======== response handling ===========
543543

544544
def parse_authn_request_response(self, xmlstr, binding, outstanding=None,
545-
outstanding_certs=None, decrypt=True, pefim=False):
545+
outstanding_certs=None):
546546
""" Deal with an AuthnResponse
547547
548548
:param xmlstr: The reply as a xml string
@@ -573,12 +573,11 @@ def parse_authn_request_response(self, xmlstr, binding, outstanding=None,
573573
"attribute_converters": self.config.attribute_converters,
574574
"allow_unknown_attributes":
575575
self.config.allow_unknown_attributes,
576-
"decrypt": decrypt
577576
}
578577
try:
579578
resp = self._parse_response(xmlstr, AuthnResponse,
580579
"assertion_consumer_service",
581-
binding, pefim=pefim, **kwargs)
580+
binding, **kwargs)
582581
except StatusError as err:
583582
logger.error("SAML status error: %s" % err)
584583
raise

src/saml2/entity.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,7 @@ def parse_manage_name_id_request_response(self, string,
978978
# ------------------------------------------------------------------------
979979

980980
def _parse_response(self, xmlstr, response_cls, service, binding,
981-
outstanding_certs=None, pefim=False, **kwargs):
981+
outstanding_certs=None, **kwargs):
982982
""" Deal with a Response
983983
984984
:param xmlstr: The response as a xml string
@@ -1040,23 +1040,23 @@ def _parse_response(self, xmlstr, response_cls, service, binding,
10401040
logger.debug("XMLSTR: %s" % xmlstr)
10411041

10421042
if response:
1043+
keys = None
10431044
if outstanding_certs:
10441045
try:
10451046
cert = outstanding_certs[response.in_response_to]
10461047
except KeyError:
1047-
key_file = ""
1048+
keys = None
10481049
else:
1049-
_, key_file = make_temp("%s" % cert["key"],
1050-
decode=False)
1051-
else:
1052-
key_file = ""
1050+
if not isinstance(cert, list):
1051+
cert = [cert]
1052+
keys = []
1053+
for _cert in cert:
1054+
keys.append(_cert["key"])
10531055
only_identity_in_encrypted_assertion = False
10541056
if "only_identity_in_encrypted_assertion" in kwargs:
10551057
only_identity_in_encrypted_assertion = kwargs["only_identity_in_encrypted_assertion"]
1056-
decrypt = True
1057-
if "decrypt" in kwargs:
1058-
decrypt = kwargs["decrypt"]
1059-
response = response.verify(key_file, decrypt=decrypt, pefim=pefim)
1058+
1059+
response = response.verify(keys)
10601060

10611061
if not response:
10621062
return None

src/saml2/response.py

Lines changed: 74 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def _verify(self):
395395
def loads(self, xmldata, decode=True, origxml=None):
396396
return self._loads(xmldata, decode, origxml)
397397

398-
def verify(self, key_file="", decrypt=True, pefim=False):
398+
def verify(self, keys=None):
399399
try:
400400
return self._verify()
401401
except AssertionError:
@@ -636,18 +636,19 @@ def get_identity(self):
636636
637637
"""
638638
ava = {}
639-
if self.assertion.advice:
640-
if self.assertion.advice.assertion:
641-
for tmp_assertion in self.assertion.advice.assertion:
642-
if tmp_assertion.attribute_statement:
643-
assert len(tmp_assertion.attribute_statement) == 1
644-
ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
645-
if self.assertion.attribute_statement:
646-
assert len(self.assertion.attribute_statement) == 1
647-
_attr_statem = self.assertion.attribute_statement[0]
648-
ava.update(self.read_attribute_statement(_attr_statem))
649-
if not ava:
650-
logger.error("Missing Attribute Statement")
639+
for _assertion in self.assertions:
640+
if _assertion.advice:
641+
if _assertion.advice.assertion:
642+
for tmp_assertion in _assertion.advice.assertion:
643+
if tmp_assertion.attribute_statement:
644+
assert len(tmp_assertion.attribute_statement) == 1
645+
ava.update(self.read_attribute_statement(tmp_assertion.attribute_statement[0]))
646+
if _assertion.attribute_statement:
647+
assert len(_assertion.attribute_statement) == 1
648+
_attr_statem = _assertion.attribute_statement[0]
649+
ava.update(self.read_attribute_statement(_attr_statem))
650+
if not ava:
651+
logger.error("Missing Attribute Statement")
651652
return ava
652653

653654
def _bearer_confirmed(self, data):
@@ -796,14 +797,14 @@ def _assertion(self, assertion, verified=False):
796797
logger.exception("get subject")
797798
raise
798799

799-
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None):
800+
def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None, verified=False):
800801
res = []
801802
for encrypted_assertion in encrypted_assertions:
802803
if encrypted_assertion.extension_elements:
803804
assertions = extension_elements_to_elements(
804805
encrypted_assertion.extension_elements, [saml, samlp])
805806
for assertion in assertions:
806-
if assertion.signature:
807+
if assertion.signature and not verified:
807808
if not self.sec.check_signature(
808809
assertion, origdoc=decr_txt,
809810
node_name=class_name(assertion), issuer=issuer):
@@ -812,7 +813,35 @@ def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None):
812813
res.append(assertion)
813814
return res
814815

815-
def parse_assertion(self, key_file="", decrypt=True, pefim=False):
816+
def find_encrypt_data_assertion(self, enc_assertions):
817+
for _assertion in enc_assertions:
818+
if _assertion.encrypted_data is not None:
819+
return True
820+
821+
def find_encrypt_data_assertion_list(self, _assertions):
822+
for _assertion in _assertions:
823+
if _assertion.advice:
824+
if _assertion.advice.encrypted_assertion:
825+
res = self.find_encrypt_data_assertion(_assertion.advice.encrypted_assertion)
826+
if res:
827+
return True
828+
829+
def find_encrypt_data(self, resp):
830+
_has_encrypt_data = False
831+
if resp.encrypted_assertion:
832+
res = self.find_encrypt_data_assertion(resp.encrypted_assertion)
833+
if res:
834+
return True
835+
if resp.assertion:
836+
for tmp_assertion in resp.assertion:
837+
if tmp_assertion.advice:
838+
if tmp_assertion.advice.encrypted_assertion:
839+
res = self.find_encrypt_data_assertion(tmp_assertion.advice.encrypted_assertion)
840+
if res:
841+
return True
842+
return False
843+
844+
def parse_assertion(self, keys=None):
816845
if self.context == "AuthnQuery":
817846
# can contain one or more assertions
818847
pass
@@ -823,30 +852,39 @@ def parse_assertion(self, key_file="", decrypt=True, pefim=False):
823852
except AssertionError:
824853
raise Exception("No assertion part")
825854

826-
has_encrypted_assertions = self.response.encrypted_assertion
827-
if not has_encrypted_assertions and self.response.assertion:
828-
for tmp_assertion in self.response.assertion:
829-
if tmp_assertion.advice:
830-
if tmp_assertion.advice.encrypted_assertion:
831-
has_encrypted_assertions = True
832-
break
855+
has_encrypted_assertions = self.find_encrypt_data(self.response) #self.response.encrypted_assertion
856+
#if not has_encrypted_assertions and self.response.assertion:
857+
# for tmp_assertion in self.response.assertion:
858+
# if tmp_assertion.advice:
859+
# if tmp_assertion.advice.encrypted_assertion:
860+
# has_encrypted_assertions = True
861+
# break
833862

834863
if self.response.assertion:
835864
logger.debug("***Unencrypted assertion***")
836865
for assertion in self.response.assertion:
837866
if not self._assertion(assertion, False):
838867
return False
839868

840-
if has_encrypted_assertions and decrypt:
869+
if has_encrypted_assertions:
841870
_enc_assertions = []
842871
logger.debug("***Encrypted assertion/-s***")
843-
decr_text = self.sec.decrypt(self.xmlstr, key_file)
844-
resp = samlp.response_from_string(decr_text)
872+
decr_text = "%s" % self.response
873+
resp = self.response
874+
while self.find_encrypt_data(resp):
875+
decr_text = self.sec.decrypt_keys(decr_text, keys)
876+
resp = samlp.response_from_string(decr_text)
845877
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text)
846-
decr_text = self.sec.decrypt(decr_text, key_file)
847-
resp = samlp.response_from_string(decr_text)
878+
while self.find_encrypt_data(resp) or self.find_encrypt_data_assertion_list(_enc_assertions):
879+
decr_text = self.sec.decrypt_keys(decr_text, keys)
880+
resp = samlp.response_from_string(decr_text)
881+
_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
882+
#_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
883+
all_assertions = _enc_assertions
848884
if resp.assertion:
849-
for tmp_ass in resp.assertion:
885+
all_assertions = all_assertions + resp.assertion
886+
if len(all_assertions) > 0:
887+
for tmp_ass in all_assertions:
850888
if tmp_ass.advice and tmp_ass.advice.encrypted_assertion:
851889
advice_res = self.decrypt_assertions(tmp_ass.advice.encrypted_assertion,
852890
decr_text,
@@ -855,21 +893,20 @@ def parse_assertion(self, key_file="", decrypt=True, pefim=False):
855893
tmp_ass.advice.assertion.extend(advice_res)
856894
else:
857895
tmp_ass.advice.assertion = advice_res
858-
if not pefim:
859-
_enc_assertions.extend(advice_res)
860896
tmp_ass.advice.encrypted_assertion = []
861897
self.response.assertion = resp.assertion
862898
for assertion in _enc_assertions:
863899
if not self._assertion(assertion, True):
864900
return False
901+
else:
902+
self.assertions.append(assertion)
903+
865904
self.xmlstr = decr_text
866905
self.response.encrypted_assertion = []
867906

868907
if self.response.assertion:
869908
for assertion in self.response.assertion:
870-
if assertion.advice and assertion.advice.assertion:
871-
for advice_assertion in assertion.advice.assertion:
872-
self.assertions.append(assertion)
909+
self.assertions.append(assertion)
873910

874911
if self.assertions and len(self.assertions) > 0:
875912
self.assertion = self.assertions[0]
@@ -880,7 +917,7 @@ def parse_assertion(self, key_file="", decrypt=True, pefim=False):
880917

881918
return True
882919

883-
def verify(self, key_file="", decrypt=True, pefim=False):
920+
def verify(self, keys=None):
884921
""" Verify that the assertion is syntactically correct and
885922
the signature is correct if present.
886923
:param key_file: If not the default key file should be used this is it.
@@ -898,7 +935,7 @@ def verify(self, key_file="", decrypt=True, pefim=False):
898935
if not isinstance(self.response, samlp.Response):
899936
return self
900937

901-
if self.parse_assertion(key_file, decrypt=decrypt, pefim=pefim):
938+
if self.parse_assertion(keys):
902939
return self
903940
else:
904941
logger.error("Could not parse the assertion")
@@ -1126,7 +1163,7 @@ def loads(self, xmldata, decode=True, origxml=None):
11261163

11271164
return self._postamble()
11281165

1129-
def verify(self, key_file="", decrypt=True, pefim=False):
1166+
def verify(self, keys=None):
11301167
try:
11311168
valid_instance(self.response)
11321169
except NotValid as exc:

src/saml2/sigver.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def encrypt(self, text, recv_key, template, session_key_type, xpath=""):
769769
return output
770770

771771
def encrypt_assertion(self, statement, enc_key, template,
772-
key_type="des-192", node_xpath=None):
772+
key_type="des-192", node_xpath=None, node_id=None):
773773
"""
774774
Will encrypt an assertion
775775
@@ -792,6 +792,8 @@ def encrypt_assertion(self, statement, enc_key, template,
792792
com_list = [self.xmlsec, "encrypt", "--pubkey-cert-pem", enc_key,
793793
"--session-key", key_type, "--xml-data", fil,
794794
"--node-xpath", node_xpath]
795+
if node_id:
796+
com_list.extend(["--node-id", node_id])
795797

796798
(_stdout, _stderr, output) = self._run_xmlsec(
797799
com_list, [tmpl], exception=EncryptError, validate_output=False)
@@ -1300,22 +1302,38 @@ def encrypt_assertion(self, statement, enc_key, template,
13001302
"""
13011303
raise NotImplemented()
13021304

1303-
def decrypt(self, enctext, key_file=None):
1305+
def decrypt_keys(self, enctext, keys=None):
13041306
""" Decrypting an encrypted text by the use of a private key.
13051307
13061308
:param enctext: The encrypted text as a string
13071309
:return: The decrypted text
13081310
"""
1311+
if not isinstance(keys, list):
1312+
keys = [keys]
13091313
_enctext = self.crypto.decrypt(enctext, self.key_file)
13101314
if _enctext is not None and len(_enctext) > 0:
13111315
return _enctext
1312-
if key_file is not None and len(key_file.strip()) > 0:
1313-
_enctext = self.crypto.decrypt(enctext, key_file)
1314-
if _enctext is not None and len(_enctext) > 0:
1315-
return _enctext
1316+
for _key in keys:
1317+
if _key is not None and len(_key.strip()) > 0:
1318+
_, key_file = make_temp("%s" % _key, decode=False)
1319+
_enctext = self.crypto.decrypt(enctext, key_file)
1320+
if _enctext is not None and len(_enctext) > 0:
1321+
return _enctext
1322+
return enctext
1323+
1324+
def decrypt(self, enctext, key_file=None):
1325+
""" Decrypting an encrypted text by the use of a private key.
1326+
1327+
:param enctext: The encrypted text as a string
1328+
:return: The decrypted text
1329+
"""
13161330
_enctext = self.crypto.decrypt(enctext, self.key_file)
13171331
if _enctext is not None and len(_enctext) > 0:
13181332
return _enctext
1333+
if key_file is not None and len(key_file.strip()) > 0:
1334+
_enctext = self.crypto.decrypt(enctext, key_file)
1335+
if _enctext is not None and len(_enctext) > 0:
1336+
return _enctext
13191337
return enctext
13201338

13211339
def verify_signature(self, signedtext, cert_file=None, cert_type="pem",

0 commit comments

Comments
 (0)