45
45
from saml2 .saml import SCM_SENDER_VOUCHES
46
46
from saml2 .saml import encrypted_attribute_from_string
47
47
from saml2 .sigver import security_context
48
+ from saml2 .sigver import DecryptError
48
49
from saml2 .sigver import SignatureError
49
50
from saml2 .sigver import signed
50
51
from saml2 .attribute_converter import to_local
@@ -896,7 +897,6 @@ def find_encrypt_data(self, resp):
896
897
:param resp: A saml response.
897
898
:return: True encrypted data exists otherwise false.
898
899
"""
899
- _has_encrypt_data = False
900
900
if resp .encrypted_assertion :
901
901
res = self .find_encrypt_data_assertion (resp .encrypted_assertion )
902
902
if res :
@@ -921,56 +921,58 @@ def parse_assertion(self, keys=None):
921
921
if self .context == "AuthnQuery" :
922
922
# can contain one or more assertions
923
923
pass
924
- else : # This is a saml2int limitation
924
+ else :
925
+ # This is a saml2int limitation
925
926
try :
926
- assert len (self .response .assertion ) == 1 or \
927
- len (self .response .encrypted_assertion ) == 1 or \
928
- self .assertion is not None
927
+ assert (
928
+ len (self .response .assertion ) == 1
929
+ or len (self .response .encrypted_assertion ) == 1
930
+ or self .assertion is not None
931
+ )
929
932
except AssertionError :
930
933
raise Exception ("No assertion part" )
931
934
932
- has_encrypted_assertions = self .find_encrypt_data (self .response ) #
933
- # self.response.encrypted_assertion
934
- # if not has_encrypted_assertions and self.response.assertion:
935
- # for tmp_assertion in self.response.assertion:
936
- # if tmp_assertion.advice:
937
- # if tmp_assertion.advice.encrypted_assertion:
938
- # has_encrypted_assertions = True
939
- # break
940
-
941
935
if self .response .assertion :
942
936
logger .debug ("***Unencrypted assertion***" )
943
937
for assertion in self .response .assertion :
944
938
if not self ._assertion (assertion , False ):
945
939
return False
946
940
947
- if has_encrypted_assertions :
948
- _enc_assertions = []
941
+ if self .find_encrypt_data (self .response ):
949
942
logger .debug ("***Encrypted assertion/-s***" )
950
- decr_text = "%s" % self . response
943
+ _enc_assertions = []
951
944
resp = self .response
952
- decr_text_old = None
953
- while self .find_encrypt_data (resp ) and decr_text_old != decr_text :
954
- decr_text_old = decr_text
955
- decr_text = self .sec .decrypt_keys (decr_text , keys )
956
- resp = samlp .response_from_string (decr_text )
957
- _enc_assertions = self .decrypt_assertions (resp .encrypted_assertion ,
958
- decr_text )
959
- decr_text_old = None
960
- while (self .find_encrypt_data (
961
- resp ) or self .find_encrypt_data_assertion_list (
962
- _enc_assertions )) and \
963
- decr_text_old != decr_text :
964
- decr_text_old = decr_text
965
- decr_text = self .sec .decrypt_keys (decr_text , keys )
966
- resp = samlp .response_from_string (decr_text )
967
- _enc_assertions = self .decrypt_assertions (
968
- resp .encrypted_assertion , decr_text , verified = True )
969
- # _enc_assertions = self.decrypt_assertions(
970
- # resp.encrypted_assertion, decr_text, verified=True)
945
+ decr_text = str (self .response )
946
+
947
+ while self .find_encrypt_data (resp ):
948
+ try :
949
+ decr_text = self .sec .decrypt_keys (decr_text , keys )
950
+ except DecryptError as e :
951
+ continue
952
+ else :
953
+ resp = samlp .response_from_string (decr_text )
954
+
955
+ _enc_assertions = self .decrypt_assertions (
956
+ resp .encrypted_assertion , decr_text
957
+ )
958
+ while (
959
+ self .find_encrypt_data (resp )
960
+ or self .find_encrypt_data_assertion_list (_enc_assertions )
961
+ ):
962
+ try :
963
+ decr_text = self .sec .decrypt_keys (decr_text , keys )
964
+ except DecryptError as e :
965
+ continue
966
+ else :
967
+ resp = samlp .response_from_string (decr_text )
968
+ _enc_assertions = self .decrypt_assertions (
969
+ resp .encrypted_assertion , decr_text , verified = True
970
+ )
971
+
971
972
all_assertions = _enc_assertions
972
973
if resp .assertion :
973
974
all_assertions = all_assertions + resp .assertion
975
+
974
976
if len (all_assertions ) > 0 :
975
977
for tmp_ass in all_assertions :
976
978
if tmp_ass .advice and tmp_ass .advice .encrypted_assertion :
@@ -985,6 +987,7 @@ def parse_assertion(self, keys=None):
985
987
tmp_ass .advice .assertion = advice_res
986
988
if len (advice_res ) > 0 :
987
989
tmp_ass .advice .encrypted_assertion = []
990
+
988
991
self .response .assertion = resp .assertion
989
992
for assertion in _enc_assertions :
990
993
if not self ._assertion (assertion , True ):
0 commit comments