@@ -395,7 +395,7 @@ def _verify(self):
395
395
def loads (self , xmldata , decode = True , origxml = None ):
396
396
return self ._loads (xmldata , decode , origxml )
397
397
398
- def verify (self , key_file = "" , decrypt = True , pefim = False ):
398
+ def verify (self , keys = None ):
399
399
try :
400
400
return self ._verify ()
401
401
except AssertionError :
@@ -636,18 +636,19 @@ def get_identity(self):
636
636
637
637
"""
638
638
ava = {}
639
- if self .assertion .advice :
640
- if self .assertion .advice .assertion :
641
- for tmp_assertion in self .assertion .advice .assertion :
642
- if tmp_assertion .attribute_statement :
643
- assert len (tmp_assertion .attribute_statement ) == 1
644
- ava .update (self .read_attribute_statement (tmp_assertion .attribute_statement [0 ]))
645
- if self .assertion .attribute_statement :
646
- assert len (self .assertion .attribute_statement ) == 1
647
- _attr_statem = self .assertion .attribute_statement [0 ]
648
- ava .update (self .read_attribute_statement (_attr_statem ))
649
- if not ava :
650
- logger .error ("Missing Attribute Statement" )
639
+ for _assertion in self .assertions :
640
+ if _assertion .advice :
641
+ if _assertion .advice .assertion :
642
+ for tmp_assertion in _assertion .advice .assertion :
643
+ if tmp_assertion .attribute_statement :
644
+ assert len (tmp_assertion .attribute_statement ) == 1
645
+ ava .update (self .read_attribute_statement (tmp_assertion .attribute_statement [0 ]))
646
+ if _assertion .attribute_statement :
647
+ assert len (_assertion .attribute_statement ) == 1
648
+ _attr_statem = _assertion .attribute_statement [0 ]
649
+ ava .update (self .read_attribute_statement (_attr_statem ))
650
+ if not ava :
651
+ logger .error ("Missing Attribute Statement" )
651
652
return ava
652
653
653
654
def _bearer_confirmed (self , data ):
@@ -796,14 +797,14 @@ def _assertion(self, assertion, verified=False):
796
797
logger .exception ("get subject" )
797
798
raise
798
799
799
- def decrypt_assertions (self , encrypted_assertions , decr_txt , issuer = None ):
800
+ def decrypt_assertions (self , encrypted_assertions , decr_txt , issuer = None , verified = False ):
800
801
res = []
801
802
for encrypted_assertion in encrypted_assertions :
802
803
if encrypted_assertion .extension_elements :
803
804
assertions = extension_elements_to_elements (
804
805
encrypted_assertion .extension_elements , [saml , samlp ])
805
806
for assertion in assertions :
806
- if assertion .signature :
807
+ if assertion .signature and not verified :
807
808
if not self .sec .check_signature (
808
809
assertion , origdoc = decr_txt ,
809
810
node_name = class_name (assertion ), issuer = issuer ):
@@ -812,7 +813,35 @@ def decrypt_assertions(self, encrypted_assertions, decr_txt, issuer=None):
812
813
res .append (assertion )
813
814
return res
814
815
815
- def parse_assertion (self , key_file = "" , decrypt = True , pefim = False ):
816
+ def find_encrypt_data_assertion (self , enc_assertions ):
817
+ for _assertion in enc_assertions :
818
+ if _assertion .encrypted_data is not None :
819
+ return True
820
+
821
+ def find_encrypt_data_assertion_list (self , _assertions ):
822
+ for _assertion in _assertions :
823
+ if _assertion .advice :
824
+ if _assertion .advice .encrypted_assertion :
825
+ res = self .find_encrypt_data_assertion (_assertion .advice .encrypted_assertion )
826
+ if res :
827
+ return True
828
+
829
+ def find_encrypt_data (self , resp ):
830
+ _has_encrypt_data = False
831
+ if resp .encrypted_assertion :
832
+ res = self .find_encrypt_data_assertion (resp .encrypted_assertion )
833
+ if res :
834
+ return True
835
+ if resp .assertion :
836
+ for tmp_assertion in resp .assertion :
837
+ if tmp_assertion .advice :
838
+ if tmp_assertion .advice .encrypted_assertion :
839
+ res = self .find_encrypt_data_assertion (tmp_assertion .advice .encrypted_assertion )
840
+ if res :
841
+ return True
842
+ return False
843
+
844
+ def parse_assertion (self , keys = None ):
816
845
if self .context == "AuthnQuery" :
817
846
# can contain one or more assertions
818
847
pass
@@ -823,30 +852,39 @@ def parse_assertion(self, key_file="", decrypt=True, pefim=False):
823
852
except AssertionError :
824
853
raise Exception ("No assertion part" )
825
854
826
- has_encrypted_assertions = self .response .encrypted_assertion
827
- if not has_encrypted_assertions and self .response .assertion :
828
- for tmp_assertion in self .response .assertion :
829
- if tmp_assertion .advice :
830
- if tmp_assertion .advice .encrypted_assertion :
831
- has_encrypted_assertions = True
832
- break
855
+ has_encrypted_assertions = self .find_encrypt_data ( self . response ) #self. response.encrypted_assertion
856
+ # if not has_encrypted_assertions and self.response.assertion:
857
+ # for tmp_assertion in self.response.assertion:
858
+ # if tmp_assertion.advice:
859
+ # if tmp_assertion.advice.encrypted_assertion:
860
+ # has_encrypted_assertions = True
861
+ # break
833
862
834
863
if self .response .assertion :
835
864
logger .debug ("***Unencrypted assertion***" )
836
865
for assertion in self .response .assertion :
837
866
if not self ._assertion (assertion , False ):
838
867
return False
839
868
840
- if has_encrypted_assertions and decrypt :
869
+ if has_encrypted_assertions :
841
870
_enc_assertions = []
842
871
logger .debug ("***Encrypted assertion/-s***" )
843
- decr_text = self .sec .decrypt (self .xmlstr , key_file )
844
- resp = samlp .response_from_string (decr_text )
872
+ decr_text = "%s" % self .response
873
+ resp = self .response
874
+ while self .find_encrypt_data (resp ):
875
+ decr_text = self .sec .decrypt_keys (decr_text , keys )
876
+ resp = samlp .response_from_string (decr_text )
845
877
_enc_assertions = self .decrypt_assertions (resp .encrypted_assertion , decr_text )
846
- decr_text = self .sec .decrypt (decr_text , key_file )
847
- resp = samlp .response_from_string (decr_text )
878
+ while self .find_encrypt_data (resp ) or self .find_encrypt_data_assertion_list (_enc_assertions ):
879
+ decr_text = self .sec .decrypt_keys (decr_text , keys )
880
+ resp = samlp .response_from_string (decr_text )
881
+ _enc_assertions = self .decrypt_assertions (resp .encrypted_assertion , decr_text , verified = True )
882
+ #_enc_assertions = self.decrypt_assertions(resp.encrypted_assertion, decr_text, verified=True)
883
+ all_assertions = _enc_assertions
848
884
if resp .assertion :
849
- for tmp_ass in resp .assertion :
885
+ all_assertions = all_assertions + resp .assertion
886
+ if len (all_assertions ) > 0 :
887
+ for tmp_ass in all_assertions :
850
888
if tmp_ass .advice and tmp_ass .advice .encrypted_assertion :
851
889
advice_res = self .decrypt_assertions (tmp_ass .advice .encrypted_assertion ,
852
890
decr_text ,
@@ -855,21 +893,20 @@ def parse_assertion(self, key_file="", decrypt=True, pefim=False):
855
893
tmp_ass .advice .assertion .extend (advice_res )
856
894
else :
857
895
tmp_ass .advice .assertion = advice_res
858
- if not pefim :
859
- _enc_assertions .extend (advice_res )
860
896
tmp_ass .advice .encrypted_assertion = []
861
897
self .response .assertion = resp .assertion
862
898
for assertion in _enc_assertions :
863
899
if not self ._assertion (assertion , True ):
864
900
return False
901
+ else :
902
+ self .assertions .append (assertion )
903
+
865
904
self .xmlstr = decr_text
866
905
self .response .encrypted_assertion = []
867
906
868
907
if self .response .assertion :
869
908
for assertion in self .response .assertion :
870
- if assertion .advice and assertion .advice .assertion :
871
- for advice_assertion in assertion .advice .assertion :
872
- self .assertions .append (assertion )
909
+ self .assertions .append (assertion )
873
910
874
911
if self .assertions and len (self .assertions ) > 0 :
875
912
self .assertion = self .assertions [0 ]
@@ -880,7 +917,7 @@ def parse_assertion(self, key_file="", decrypt=True, pefim=False):
880
917
881
918
return True
882
919
883
- def verify (self , key_file = "" , decrypt = True , pefim = False ):
920
+ def verify (self , keys = None ):
884
921
""" Verify that the assertion is syntactically correct and
885
922
the signature is correct if present.
886
923
:param key_file: If not the default key file should be used this is it.
@@ -898,7 +935,7 @@ def verify(self, key_file="", decrypt=True, pefim=False):
898
935
if not isinstance (self .response , samlp .Response ):
899
936
return self
900
937
901
- if self .parse_assertion (key_file , decrypt = decrypt , pefim = pefim ):
938
+ if self .parse_assertion (keys ):
902
939
return self
903
940
else :
904
941
logger .error ("Could not parse the assertion" )
@@ -1126,7 +1163,7 @@ def loads(self, xmldata, decode=True, origxml=None):
1126
1163
1127
1164
return self ._postamble ()
1128
1165
1129
- def verify (self , key_file = "" , decrypt = True , pefim = False ):
1166
+ def verify (self , keys = None ):
1130
1167
try :
1131
1168
valid_instance (self .response )
1132
1169
except NotValid as exc :
0 commit comments