Skip to content

Commit a12cc2a

Browse files
author
Roland Hedberg
committed
Fixed various errors.
1 parent dfad000 commit a12cc2a

File tree

6 files changed

+126
-54
lines changed

6 files changed

+126
-54
lines changed

src/saml2/client.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,14 @@ def do_logout(self, name_id, entity_ids, reason, expire, sign=None,
206206

207207
destination = destinations(srvs)[0]
208208
logger.info("destination to provider: %s" % destination)
209+
try:
210+
session_info = self.users.get_info_from(name_id, entity_id)
211+
session_indexes = [session_info['session_index']]
212+
except KeyError:
213+
session_indexes = None
209214
req_id, request = self.create_logout_request(
210215
destination, entity_id, name_id=name_id, reason=reason,
211-
expire=expire)
216+
expire=expire, session_indexes=session_indexes)
212217

213218
# to_sign = []
214219
if binding.startswith("http://"):

src/saml2/metadata.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,17 @@
5959
XMLNSXS = " xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
6060
bXMLNSXS = b" xmlns:xs=\"http://www.w3.org/2001/XMLSchema\""
6161

62+
6263
def metadata_tostring_fix(desc, nspair, xmlstring=""):
6364
if not xmlstring:
6465
xmlstring = desc.to_string(nspair)
6566

6667
if six.PY2:
6768
if "\"xs:string\"" in xmlstring and XMLNSXS not in xmlstring:
68-
xmlstring = xmlstring.replace(MDNS, MDNS+XMLNSXS)
69+
xmlstring = xmlstring.replace(MDNS, MDNS + XMLNSXS)
6970
else:
7071
if b"\"xs:string\"" in xmlstring and bXMLNSXS not in xmlstring:
71-
xmlstring = xmlstring.replace(bMDNS, bMDNS+bXMLNSXS)
72+
xmlstring = xmlstring.replace(bMDNS, bMDNS + bXMLNSXS)
7273

7374
return xmlstring
7475

@@ -77,7 +78,7 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
7778
keyfile=None, mid=None, name=None, sign=None):
7879
valid_for = 0
7980
nspair = {"xs": "http://www.w3.org/2001/XMLSchema"}
80-
#paths = [".", "/opt/local/bin"]
81+
# paths = [".", "/opt/local/bin"]
8182

8283
if valid:
8384
valid_for = int(valid) # Hours
@@ -97,21 +98,17 @@ def create_metadata_string(configfile, config=None, valid=None, cert=None,
9798
secc = security_context(conf)
9899

99100
if mid:
100-
desc = entities_descriptor(eds, valid_for, name, mid,
101-
sign, secc)
102-
valid_instance(desc)
103-
104-
return metadata_tostring_fix(desc, nspair)
101+
eid, xmldoc = entities_descriptor(eds, valid_for, name, mid,
102+
sign, secc)
105103
else:
106104
eid = eds[0]
107105
if sign:
108106
eid, xmldoc = sign_entity_descriptor(eid, mid, secc)
109107
else:
110108
xmldoc = None
111109

112-
valid_instance(eid)
113-
xmldoc = metadata_tostring_fix(eid, nspair, xmldoc)
114-
return xmldoc
110+
valid_instance(eid)
111+
return metadata_tostring_fix(eid, nspair, xmldoc)
115112

116113

117114
def _localized_name(val, klass):
@@ -346,6 +343,7 @@ def do_idpdisc(discovery_response):
346343
return idpdisc.DiscoveryResponse(index="0", location=discovery_response,
347344
binding=idpdisc.NAMESPACE)
348345

346+
349347
ENDPOINTS = {
350348
"sp": {
351349
"artifact_resolution_service": (md.ArtifactResolutionService, True),
@@ -425,7 +423,8 @@ def do_endpoints(conf, endpoints):
425423
servs = []
426424
i = 1
427425
for args in conf[endpoint]:
428-
if isinstance(args, six.string_types): # Assume it's the location
426+
if isinstance(args,
427+
six.string_types): # Assume it's the location
429428
args = {"location": args,
430429
"binding": DEFAULT_BINDING[endpoint]}
431430
elif isinstance(args, tuple) or isinstance(args, list):
@@ -453,16 +452,16 @@ def do_endpoints(conf, endpoints):
453452
pass
454453
return service
455454

455+
456456
DEFAULT = {
457457
"want_assertions_signed": "true",
458458
"authn_requests_signed": "false",
459459
"want_authn_requests_signed": "false",
460-
#"want_authn_requests_only_with_valid_cert": "false",
460+
# "want_authn_requests_only_with_valid_cert": "false",
461461
}
462462

463463

464464
def do_attribute_consuming_service(conf, spsso):
465-
466465
service_description = service_name = None
467466
requested_attributes = []
468467
acs = conf.attribute_converters
@@ -557,7 +556,8 @@ def do_spsso_descriptor(conf, cert=None, enc_cert=None):
557556

558557
if cert or enc_cert:
559558
metadata_key_usage = conf.metadata_key_usage
560-
spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert, use=metadata_key_usage)
559+
spsso.key_descriptor = do_key_descriptor(cert=cert, enc_cert=enc_cert,
560+
use=metadata_key_usage)
561561

562562
for key in ["want_assertions_signed", "authn_requests_signed"]:
563563
try:
@@ -605,10 +605,11 @@ def do_idpsso_descriptor(conf, cert=None, enc_cert=None):
605605
idpsso.extensions.add_extension_element(do_uiinfo(ui_info))
606606

607607
if cert or enc_cert:
608-
idpsso.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
608+
idpsso.key_descriptor = do_key_descriptor(cert, enc_cert,
609+
use=conf.metadata_key_usage)
609610

610611
for key in ["want_authn_requests_signed"]:
611-
#"want_authn_requests_only_with_valid_cert"]:
612+
# "want_authn_requests_only_with_valid_cert"]:
612613
try:
613614
val = conf.getattr(key, "idp")
614615
if val is None:
@@ -635,7 +636,8 @@ def do_aa_descriptor(conf, cert=None, enc_cert=None):
635636
_do_nameid_format(aad, conf, "aa")
636637

637638
if cert or enc_cert:
638-
aad.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
639+
aad.key_descriptor = do_key_descriptor(cert, enc_cert,
640+
use=conf.metadata_key_usage)
639641

640642
attributes = conf.getattr("attribute", "aa")
641643
if attributes:
@@ -664,7 +666,8 @@ def do_aq_descriptor(conf, cert=None, enc_cert=None):
664666
_do_nameid_format(aqs, conf, "aq")
665667

666668
if cert or enc_cert:
667-
aqs.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
669+
aqs.key_descriptor = do_key_descriptor(cert, enc_cert,
670+
use=conf.metadata_key_usage)
668671

669672
return aqs
670673

@@ -685,7 +688,8 @@ def do_pdp_descriptor(conf, cert=None, enc_cert=None):
685688
_do_nameid_format(pdp, conf, "pdp")
686689

687690
if cert:
688-
pdp.key_descriptor = do_key_descriptor(cert, enc_cert, use=conf.metadata_key_usage)
691+
pdp.key_descriptor = do_key_descriptor(cert, enc_cert,
692+
use=conf.metadata_key_usage)
689693

690694
return pdp
691695

@@ -702,7 +706,8 @@ def entity_descriptor(confd):
702706
if confd.encryption_keypairs is not None:
703707
enc_cert = []
704708
for _encryption in confd.encryption_keypairs:
705-
enc_cert.append("".join(open(_encryption["cert_file"]).readlines()[1:-1]))
709+
enc_cert.append(
710+
"".join(open(_encryption["cert_file"]).readlines()[1:-1]))
706711

707712
entd = md.EntityDescriptor()
708713
entd.entity_id = confd.entityid
@@ -736,13 +741,15 @@ def entity_descriptor(confd):
736741
entd.idpsso_descriptor = do_idpsso_descriptor(confd, mycert, enc_cert)
737742
if "aa" in serves:
738743
confd.context = "aa"
739-
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert, enc_cert)
744+
entd.attribute_authority_descriptor = do_aa_descriptor(confd, mycert,
745+
enc_cert)
740746
if "pdp" in serves:
741747
confd.context = "pdp"
742748
entd.pdp_descriptor = do_pdp_descriptor(confd, mycert, enc_cert)
743749
if "aq" in serves:
744750
confd.context = "aq"
745-
entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert, enc_cert)
751+
entd.authn_authority_descriptor = do_aq_descriptor(confd, mycert,
752+
enc_cert)
746753

747754
return entd
748755

src/saml2/response.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1036,9 +1036,11 @@ def session_info(self):
10361036
"issuer": self.issuer(), "not_on_or_after": nooa,
10371037
"authz_decision_info": self.authz_decision_info()}
10381038
else:
1039+
authn_statement = self.assertion.authn_statement[0]
10391040
return {"ava": self.ava, "name_id": self.name_id,
10401041
"came_from": self.came_from, "issuer": self.issuer(),
1041-
"not_on_or_after": nooa, "authn_info": self.authn_info()}
1042+
"not_on_or_after": nooa, "authn_info": self.authn_info(),
1043+
"session_index": authn_statement.session_index}
10421044

10431045
def __str__(self):
10441046
if not isinstance(self.xmlstr, six.string_types):

src/saml2/sigver.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,10 @@ def make_temp(string, suffix="", decode=True, delete=True):
353353
xmlsec function).
354354
"""
355355
ntf = NamedTemporaryFile(suffix=suffix, delete=delete)
356-
assert isinstance(string, six.binary_type)
356+
# Python3 tempfile requires byte-like object
357+
if not isinstance(string, six.binary_type):
358+
string = string.encode("utf8")
359+
357360
if decode:
358361
ntf.write(base64.b64decode(string))
359362
else:
@@ -657,6 +660,12 @@ def verify_redirect_signature(saml_msg, cert=None, sigkey=None):
657660
LOG_LINE_2 = 60 * "=" + "\n%s\n%s\n" + 60 * "-" + "\n%s" + 60 * "="
658661

659662

663+
def make_str(txt):
664+
if isinstance(txt, six.string_types):
665+
return txt
666+
else:
667+
return txt.decode("utf8")
668+
660669
# ---------------------------------------------------------------------------
661670

662671

@@ -674,29 +683,32 @@ def read_cert_from_file(cert_file, cert_type):
674683
return ""
675684

676685
if cert_type == "pem":
677-
line = open(cert_file).read().replace("\r\n", "\n").split("\n")
678-
679-
if line[0] == "-----BEGIN CERTIFICATE-----":
680-
line = line[1:]
681-
elif line[0] == "-----BEGIN PUBLIC KEY-----":
682-
line = line[1:]
686+
_a = read_file(cert_file, 'rb').decode("utf8")
687+
_b = _a.replace("\r\n", "\n")
688+
lines = _b.split("\n")
689+
690+
for pattern in ("-----BEGIN CERTIFICATE-----",
691+
"-----BEGIN PUBLIC KEY-----"):
692+
if pattern in lines:
693+
lines = lines[lines.index(pattern)+1:]
694+
break
683695
else:
684696
raise CertificateError("Strange beginning of PEM file")
685697

686-
while line[-1] == "":
687-
line = line[:-1]
688-
689-
if line[-1] == "-----END CERTIFICATE-----":
690-
line = line[:-1]
691-
elif line[-1] == "-----END PUBLIC KEY-----":
692-
line = line[:-1]
698+
for pattern in ("-----END CERTIFICATE-----",
699+
"-----END PUBLIC KEY-----"):
700+
if pattern in lines:
701+
lines = lines[:lines.index(pattern)]
702+
break
693703
else:
694704
raise CertificateError("Strange end of PEM file")
695-
return "".join(line)
705+
return make_str("".join(lines).encode("utf8"))
706+
696707

697708
if cert_type in ["der", "cer", "crt"]:
698-
data = read_file(cert_file)
699-
return base64.b64encode(str(data))
709+
data = read_file(cert_file, 'rb')
710+
_cert = base64.b64encode(data)
711+
return make_str(_cert)
700712

701713

702714
class CryptoBackend():
@@ -850,8 +862,8 @@ def sign_statement(self, statement, node_name, key_file, node_id,
850862
'id','Id' or 'ID'
851863
:return: The signed statement
852864
"""
853-
if not isinstance(statement, six.binary_type):
854-
statement = str(statement).encode('utf-8')
865+
if isinstance(statement, SamlBase):
866+
statement = str(statement)
855867

856868
_, fil = make_temp(statement, suffix=".xml",
857869
decode=False, delete=self._xmlsec_delete_tmpfiles)
@@ -1284,8 +1296,6 @@ def __init__(self, crypto, key_file="", key_type="pem",
12841296
self.encryption_keypairs = encryption_keypairs
12851297
self.enc_cert_type = enc_cert_type
12861298

1287-
1288-
12891299
self.my_cert = read_cert_from_file(cert_file, cert_type)
12901300

12911301
self.cert_handler = CertHandler(self, cert_file, cert_type, key_file,

tests/test_40_sigver.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
#!/usr/bin/env python
22

33
import base64
4-
from saml2.sigver import pre_encryption_part, make_temp, XmlsecError, \
5-
SigverError
6-
from saml2.mdstore import MetadataStore
7-
from saml2.saml import assertion_from_string, EncryptedAssertion
8-
from saml2.samlp import response_from_string
9-
10-
from saml2 import sigver, extension_elements_to_elements
4+
from saml2 import sigver
5+
from saml2 import extension_elements_to_elements
116
from saml2 import class_name
127
from saml2 import time_util
138
from saml2 import saml, samlp
149
from saml2 import config
10+
from saml2.sigver import pre_encryption_part
11+
from saml2.sigver import make_temp
12+
from saml2.sigver import XmlsecError
13+
from saml2.sigver import SigverError
14+
from saml2.mdstore import MetadataStore
15+
from saml2.saml import assertion_from_string
16+
from saml2.saml import EncryptedAssertion
17+
from saml2.samlp import response_from_string
1518
from saml2.s_utils import factory, do_attribute_statement
1619

1720
from py.test import raises
@@ -510,6 +513,6 @@ def test_xmlsec_err():
510513
if __name__ == "__main__":
511514
t = TestSecurity()
512515
t.setup_class()
513-
t.test_verify_1()
516+
t.test_sign_assertion()
514517

515518
#test_xmlsec_err()

0 commit comments

Comments
 (0)