|
1 | 1 | import base64
|
2 | 2 | #from binascii import hexlify
|
| 3 | +import copy |
3 | 4 | import logging
|
4 | 5 | from hashlib import sha1
|
5 | 6 | from Crypto.PublicKey import RSA
|
|
37 | 38 | from saml2.s_utils import success_status_factory
|
38 | 39 | from saml2.s_utils import decode_base64_and_inflate
|
39 | 40 | from saml2.s_utils import UnsupportedBinding
|
40 |
| -from saml2.samlp import AuthnRequest, SessionIndex |
| 41 | +from saml2.samlp import AuthnRequest, SessionIndex, response_from_string |
41 | 42 | from saml2.samlp import AuthzDecisionQuery
|
42 | 43 | from saml2.samlp import AuthnQuery
|
43 | 44 | from saml2.samlp import AssertionIDRequest
|
@@ -502,10 +503,46 @@ def _add_info(self, msg, **kwargs):
|
502 | 503 | else:
|
503 | 504 | msg.extension_elements = extensions
|
504 | 505 |
|
| 506 | + def has_encrypt_cert_in_metadata(self, sp_entity_id): |
| 507 | + if sp_entity_id is not None: |
| 508 | + _certs = self.metadata.certs(sp_entity_id, "any", "encryption") |
| 509 | + if len(_certs) > 0: |
| 510 | + return True |
| 511 | + return False |
| 512 | + |
| 513 | + |
| 514 | + def _encrypt_assertion(self, encrypt_cert, sp_entity_id, response, node_xpath=None): |
| 515 | + _certs = [] |
| 516 | + cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary) |
| 517 | + if encrypt_cert: |
| 518 | + _certs = [] |
| 519 | + _certs.append(encrypt_cert) |
| 520 | + elif sp_entity_id is not None: |
| 521 | + _certs = self.metadata.certs(sp_entity_id, "any", "encryption") |
| 522 | + exception = None |
| 523 | + for _cert in _certs: |
| 524 | + try: |
| 525 | + begin_cert = "-----BEGIN CERTIFICATE-----\n" |
| 526 | + end_cert = "\n-----END CERTIFICATE-----\n" |
| 527 | + if begin_cert not in _cert: |
| 528 | + _cert = "%s%s" % (begin_cert, _cert) |
| 529 | + if end_cert not in _cert: |
| 530 | + _cert = "%s%s" % (_cert, end_cert) |
| 531 | + _, cert_file = make_temp(_cert, decode=False) |
| 532 | + response = cbxs.encrypt_assertion(response, cert_file, |
| 533 | + pre_encryption_part(), node_xpath=node_xpath) |
| 534 | + return response |
| 535 | + except Exception as ex: |
| 536 | + exception = ex |
| 537 | + pass |
| 538 | + if exception: |
| 539 | + raise exception |
| 540 | + return response |
| 541 | + |
505 | 542 | def _response(self, in_response_to, consumer_url=None, status=None,
|
506 |
| - issuer=None, sign=False, to_sign=None, |
| 543 | + issuer=None, sign=False, to_sign=None, sp_entity_id=None, |
507 | 544 | encrypt_assertion=False, encrypt_assertion_self_contained=False, encrypted_advice_attributes=False,
|
508 |
| - encrypt_cert=None, **kwargs): |
| 545 | + encrypt_cert_advice=None, encrypt_cert_assertion=None,sign_assertion=None, pefim=False, **kwargs): |
509 | 546 | """ Create a Response.
|
510 | 547 | Encryption:
|
511 | 548 | encrypt_assertion must be true for encryption to be performed. If encrypted_advice_attributes also is
|
@@ -542,43 +579,79 @@ def _response(self, in_response_to, consumer_url=None, status=None,
|
542 | 579 | if not sign and to_sign and not encrypt_assertion:
|
543 | 580 | return signed_instance_factory(response, self.sec, to_sign)
|
544 | 581 |
|
545 |
| - if encrypt_assertion: |
546 |
| - node_xpath = None |
| 582 | + has_encrypt_cert = self.has_encrypt_cert_in_metadata(sp_entity_id) |
| 583 | + if not has_encrypt_cert and encrypt_cert_advice is None: |
| 584 | + encrypted_advice_attributes = False |
| 585 | + if not has_encrypt_cert and encrypt_cert_assertion is None: |
| 586 | + encrypt_assertion = False |
| 587 | + |
| 588 | + if encrypt_assertion or (encrypted_advice_attributes and response.assertion.advice is not None and |
| 589 | + len(response.assertion.advice.assertion) == 1): |
547 | 590 | if sign:
|
548 | 591 | response.signature = pre_signature_part(response.id,
|
549 | 592 | self.sec.my_cert, 1)
|
550 | 593 | sign_class = [(class_name(response), response.id)]
|
551 | 594 | cbxs = CryptoBackendXmlSec1(self.config.xmlsec_binary)
|
| 595 | + encrypt_advice = False |
552 | 596 | if encrypted_advice_attributes and response.assertion.advice is not None \
|
553 |
| - and len(response.assertion.advice.assertion) == 1: |
554 |
| - tmp_assertion = response.assertion.advice.assertion[0] |
555 |
| - response.assertion.advice.encrypted_assertion = [] |
556 |
| - response.assertion.advice.encrypted_assertion.append(EncryptedAssertion()) |
557 |
| - if isinstance(tmp_assertion, list): |
558 |
| - response.assertion.advice.encrypted_assertion[0].add_extension_elements(tmp_assertion) |
559 |
| - else: |
560 |
| - response.assertion.advice.encrypted_assertion[0].add_extension_element(tmp_assertion) |
561 |
| - response.assertion.advice.assertion = [] |
| 597 | + and len(response.assertion.advice.assertion) > 0: |
| 598 | + _assertions = response.assertion |
| 599 | + if not isinstance(_assertions, list): |
| 600 | + _assertions = [_assertions] |
| 601 | + for _assertion in _assertions: |
| 602 | + _assertion.advice.encrypted_assertion = [] |
| 603 | + _assertion.advice.encrypted_assertion.append(EncryptedAssertion()) |
| 604 | + _advice_assertions = copy.deepcopy(_assertion.advice.assertion) |
| 605 | + _assertion.advice.assertion = [] |
| 606 | + if not isinstance(_advice_assertions, list): |
| 607 | + _advice_assertions = [_advice_assertions] |
| 608 | + for tmp_assertion in _advice_assertions: |
| 609 | + to_sign_advice = [] |
| 610 | + if sign_assertion and not pefim: |
| 611 | + tmp_assertion.signature = pre_signature_part(tmp_assertion.id, self.sec.my_cert, 1) |
| 612 | + to_sign_advice.append((class_name(tmp_assertion), tmp_assertion.id)) |
| 613 | + #tmp_assertion = response.assertion.advice.assertion[0] |
| 614 | + _assertion.advice.encrypted_assertion[0].add_extension_element(tmp_assertion) |
| 615 | + |
| 616 | + if encrypt_assertion_self_contained: |
| 617 | + advice_tag = response.assertion.advice._to_element_tree().tag |
| 618 | + assertion_tag = tmp_assertion._to_element_tree().tag |
| 619 | + response = \ |
| 620 | + response.get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion( |
| 621 | + assertion_tag, advice_tag) |
| 622 | + node_xpath = ''.join(["/*[local-name()=\"%s\"]" % v for v in |
| 623 | + ["Response", "Assertion", "Advice", "EncryptedAssertion", "Assertion"]]) |
| 624 | + |
| 625 | + if to_sign_advice: |
| 626 | + response = signed_instance_factory(response, self.sec, to_sign_advice) |
| 627 | + response = self._encrypt_assertion(encrypt_cert_advice, sp_entity_id, response, node_xpath=node_xpath) |
| 628 | + response = response_from_string(response) |
| 629 | + |
| 630 | + if encrypt_assertion: |
| 631 | + to_sign_assertion = [] |
| 632 | + if sign_assertion is not None and sign_assertion: |
| 633 | + _assertions = response.assertion |
| 634 | + if not isinstance(_assertions, list): |
| 635 | + _assertions = [_assertions] |
| 636 | + for _assertion in _assertions: |
| 637 | + _assertion.signature = pre_signature_part(_assertion.id, self.sec.my_cert, 1) |
| 638 | + to_sign_assertion.append((class_name(_assertion), _assertion.id)) |
562 | 639 | if encrypt_assertion_self_contained:
|
563 |
| - advice_tag = response.assertion.advice._to_element_tree().tag |
564 |
| - assertion_tag = tmp_assertion._to_element_tree().tag |
565 |
| - response = response.get_xml_string_with_self_contained_assertion_within_advice_encrypted_assertion( |
566 |
| - assertion_tag, advice_tag) |
567 |
| - node_xpath = ''.join(["/*[local-name()=\"%s\"]" % v for v in |
568 |
| - ["Response", "Assertion", "Advice", "EncryptedAssertion", "Assertion"]]) |
569 |
| - elif encrypt_assertion_self_contained: |
570 |
| - assertion_tag = response.assertion._to_element_tree().tag |
571 |
| - response = pre_encrypt_assertion(response) |
572 |
| - response = response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion( |
573 |
| - assertion_tag) |
| 640 | + try: |
| 641 | + assertion_tag = response.assertion._to_element_tree().tag |
| 642 | + except: |
| 643 | + assertion_tag = response.assertion[0]._to_element_tree().tag |
| 644 | + response = pre_encrypt_assertion(response) |
| 645 | + response = response.get_xml_string_with_self_contained_assertion_within_encrypted_assertion( |
| 646 | + assertion_tag) |
| 647 | + else: |
| 648 | + response = pre_encrypt_assertion(response) |
| 649 | + if to_sign_assertion: |
| 650 | + response = signed_instance_factory(response, self.sec, to_sign_assertion) |
| 651 | + response = self._encrypt_assertion(encrypt_cert_assertion, sp_entity_id, response) |
574 | 652 | else:
|
575 |
| - response = pre_encrypt_assertion(response) |
576 |
| - if to_sign: |
577 |
| - response = signed_instance_factory(response, self.sec, to_sign) |
578 |
| - _, cert_file = make_temp("%s" % encrypt_cert, decode=False) |
579 |
| - response = cbxs.encrypt_assertion(response, cert_file, |
580 |
| - pre_encryption_part(), node_xpath=node_xpath) |
581 |
| - # template(response.assertion.id)) |
| 653 | + if to_sign: |
| 654 | + response = signed_instance_factory(response, self.sec, to_sign) |
582 | 655 | if sign:
|
583 | 656 | return signed_instance_factory(response, self.sec, sign_class)
|
584 | 657 | else:
|
@@ -968,23 +1041,23 @@ def _parse_response(self, xmlstr, response_cls, service, binding,
|
968 | 1041 | logger.debug("XMLSTR: %s" % xmlstr)
|
969 | 1042 |
|
970 | 1043 | if response:
|
| 1044 | + keys = None |
971 | 1045 | if outstanding_certs:
|
972 | 1046 | try:
|
973 | 1047 | cert = outstanding_certs[response.in_response_to]
|
974 | 1048 | except KeyError:
|
975 |
| - key_file = "" |
| 1049 | + keys = None |
976 | 1050 | else:
|
977 |
| - _, key_file = make_temp("%s" % cert["key"], |
978 |
| - decode=False) |
979 |
| - else: |
980 |
| - key_file = "" |
| 1051 | + if not isinstance(cert, list): |
| 1052 | + cert = [cert] |
| 1053 | + keys = [] |
| 1054 | + for _cert in cert: |
| 1055 | + keys.append(_cert["key"]) |
981 | 1056 | only_identity_in_encrypted_assertion = False
|
982 | 1057 | if "only_identity_in_encrypted_assertion" in kwargs:
|
983 | 1058 | only_identity_in_encrypted_assertion = kwargs["only_identity_in_encrypted_assertion"]
|
984 |
| - decrypt = True |
985 |
| - if "decrypt" in kwargs: |
986 |
| - decrypt = kwargs["decrypt"] |
987 |
| - response = response.verify(key_file, decrypt=decrypt) |
| 1059 | + |
| 1060 | + response = response.verify(keys) |
988 | 1061 |
|
989 | 1062 | if not response:
|
990 | 1063 | return None
|
|
0 commit comments