@@ -628,7 +628,7 @@ def condition_ok(self, lax=False):
628
628
629
629
return True
630
630
631
- def decrypt_attributes (self , attribute_statement ):
631
+ def decrypt_attributes (self , attribute_statement , keys = None ):
632
632
"""
633
633
Decrypts possible encrypted attributes and adds the decrypts to the
634
634
list of attributes.
@@ -642,11 +642,11 @@ def decrypt_attributes(self, attribute_statement):
642
642
643
643
for encattr in attribute_statement .encrypted_attribute :
644
644
if not encattr .encrypted_key :
645
- _decr = self .sec .decrypt (encattr .encrypted_data )
645
+ _decr = self .sec .decrypt_keys (encattr .encrypted_data , keys = keys )
646
646
_attr = attribute_from_string (_decr )
647
647
attribute_statement .attribute .append (_attr )
648
648
else :
649
- _decr = self .sec .decrypt (encattr )
649
+ _decr = self .sec .decrypt_keys (encattr , keys = keys )
650
650
enc_attr = encrypted_attribute_from_string (_decr )
651
651
attrlist = enc_attr .extensions_as_elements ("Attribute" , saml )
652
652
attribute_statement .attribute .extend (attrlist )
@@ -734,7 +734,7 @@ def _holder_of_key_confirmed(self, data):
734
734
735
735
return has_keyinfo
736
736
737
- def get_subject (self ):
737
+ def get_subject (self , keys = None ):
738
738
""" The assertion must contain a Subject
739
739
"""
740
740
@@ -785,8 +785,9 @@ def get_subject(self):
785
785
self .name_id = subject .name_id
786
786
elif subject .encrypted_id :
787
787
# decrypt encrypted ID
788
- _name_id_str = self .sec .decrypt (
789
- subject .encrypted_id .encrypted_data .to_string ())
788
+ _name_id_str = self .sec .decrypt_keys (
789
+ subject .encrypted_id .encrypted_data .to_string (), keys = keys
790
+ )
790
791
_name_id = saml .name_id_from_string (_name_id_str )
791
792
self .name_id = _name_id
792
793
@@ -958,7 +959,7 @@ def parse_assertion(self, keys=None):
958
959
while self .find_encrypt_data (resp ) and decr_text_old != decr_text :
959
960
decr_text_old = decr_text
960
961
try :
961
- decr_text = self .sec .decrypt_keys (decr_text , keys )
962
+ decr_text = self .sec .decrypt_keys (decr_text , keys = keys )
962
963
except DecryptError as e :
963
964
continue
964
965
else :
@@ -981,7 +982,7 @@ def parse_assertion(self, keys=None):
981
982
) and decr_text_old != decr_text :
982
983
decr_text_old = decr_text
983
984
try :
984
- decr_text = self .sec .decrypt_keys (decr_text , keys )
985
+ decr_text = self .sec .decrypt_keys (decr_text , keys = keys )
985
986
except DecryptError as e :
986
987
continue
987
988
else :
0 commit comments