Skip to content

Commit 1fde97d

Browse files
committed
Fix bytes/strings logical issues
Hashes, and other calls, require bytes or strings in python3 where they were different in python 2.x.
1 parent 2922132 commit 1fde97d

File tree

6 files changed

+49
-9
lines changed

6 files changed

+49
-9
lines changed

src/saml2/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import re
99
import logging
1010
import logging.handlers
11+
import six
1112

1213
from importlib import import_module
1314

@@ -296,7 +297,7 @@ def load_complex(self, cnf, typ="", metadata_construction=False):
296297

297298
def unicode_convert(self, item):
298299
try:
299-
return unicode(item, "utf-8")
300+
return six.text_type(item, "utf-8")
300301
except TypeError:
301302
_uc = self.unicode_convert
302303
if isinstance(item, dict):

src/saml2/eptid.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import shelve
99

1010
import logging
11+
import six
1112

1213
logger = logging.getLogger(__name__)
1314

@@ -21,13 +22,23 @@ def make(self, idp, sp, args):
2122
md5 = hashlib.md5()
2223
for arg in args:
2324
md5.update(arg.encode("utf-8"))
24-
md5.update(sp)
25-
md5.update(self.secret)
25+
if isinstance(sp, six.binary_type):
26+
md5.update(sp)
27+
else:
28+
md5.update(sp.encode('utf-8'))
29+
if isinstance(self.secret, six.binary_type):
30+
md5.update(self.secret)
31+
else:
32+
md5.update(self.secret.encode('utf-8'))
2633
md5.digest()
2734
hashval = md5.hexdigest()
35+
if isinstance(hashval, six.binary_type):
36+
hashval = hashval.decode('ascii')
2837
return "!".join([idp, sp, hashval])
2938

3039
def __getitem__(self, key):
40+
if six.PY3 and isinstance(key, six.binary_type):
41+
key = key.decode('utf-8')
3142
return self._db[key]
3243

3344
def __setitem__(self, key, value):

src/saml2/ident.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from six.moves.urllib.parse import quote
88
from six.moves.urllib.parse import unquote
99
from saml2 import SAMLError
10-
from saml2.s_utils import rndstr
10+
from saml2.s_utils import rndbytes
1111
from saml2.s_utils import PolicyError
1212
from saml2.saml import NameID
1313
from saml2.saml import NAMEID_FORMAT_PERSISTENT
@@ -46,6 +46,16 @@ class that is used.
4646
return ",".join(_res)
4747

4848

49+
def code_binary(item):
50+
"""
51+
Return a binary 'code' suitable for hashing.
52+
"""
53+
code_str = code(item)
54+
if isinstance(code_str, six.string_types):
55+
return code_str.encode('utf-8')
56+
return code_str
57+
58+
4959
def decode(txt):
5060
"""Turns a coded string by code() into a NameID class instance.
5161
@@ -75,11 +85,17 @@ def __init__(self, db, domain="", name_qualifier=""):
7585
self.name_qualifier = name_qualifier
7686

7787
def _create_id(self, nformat, name_qualifier="", sp_name_qualifier=""):
78-
_id = sha256(rndstr(32))
88+
_id = sha256(rndbytes(32))
89+
if not isinstance(nformat, six.binary_type):
90+
nformat = nformat.encode('utf-8')
7991
_id.update(nformat)
8092
if name_qualifier:
93+
if not isinstance(name_qualifier, six.binary_type):
94+
name_qualifier = name_qualifier.encode('utf-8')
8195
_id.update(name_qualifier)
8296
if sp_name_qualifier:
97+
if not isinstance(sp_name_qualifier, six.binary_type):
98+
sp_name_qualifier = sp_name_qualifier.encode('utf-8')
8399
_id.update(sp_name_qualifier)
84100
return _id.hexdigest()
85101

src/saml2/mdstore.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import sys
55
import json
6+
import six
67

78
from hashlib import sha1
89
from os.path import isfile, join
@@ -487,6 +488,8 @@ def construct_source_id(self):
487488
try:
488489
for srv in ent[desc]:
489490
if "artifact_resolution_service" in srv:
491+
if isinstance(eid, six.string_types):
492+
eid = eid.encode('utf-8')
490493
s = sha1(eid)
491494
res[s.digest()] = ent
492495
except KeyError:

src/saml2/mongo_store.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from saml2.mdstore import InMemoryMetaData
1010
from saml2.s_utils import PolicyError
1111

12-
from saml2.ident import code, IdentDB, Unknown
12+
from saml2.ident import code_binary, IdentDB, Unknown
1313
from saml2.mdie import to_dict, from_dict
1414

1515
from saml2 import md
@@ -59,7 +59,7 @@ def __init__(self, database="", collection="assertion", **kwargs):
5959

6060
def store_assertion(self, assertion, to_sign):
6161
name_id = assertion.subject.name_id
62-
nkey = sha1(code(name_id)).hexdigest()
62+
nkey = sha1(code_binary(name_id)).hexdigest()
6363

6464
doc = {
6565
"name_id_key": nkey,
@@ -94,7 +94,7 @@ def get_assertions_by_subject(self, name_id=None, session_index=None,
9494
:return:
9595
"""
9696
result = []
97-
key = sha1(code(name_id)).hexdigest()
97+
key = sha1(code_binary(name_id)).hexdigest()
9898
for item in self.assertion.find({"name_id_key": key}):
9999
assertion = from_dict(item["assertion"], ONTS, True)
100100
if session_index or requested_context:
@@ -114,7 +114,7 @@ def get_assertions_by_subject(self, name_id=None, session_index=None,
114114

115115
def remove_authn_statements(self, name_id):
116116
logger.debug("remove authn about: %s" % name_id)
117-
key = sha1(code(name_id)).hexdigest()
117+
key = sha1(code_binary(name_id)).hexdigest()
118118
for item in self.assertion.find({"name_id_key": key}):
119119
self.assertion.remove(item["_id"])
120120

src/saml2/s_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ def rndstr(size=16, alphabet=""):
167167
alphabet = string.ascii_letters[0:52] + string.digits
168168
return type(alphabet)().join(rng.choice(alphabet) for _ in range(size))
169169

170+
def rndbytes(size=16, alphabet=""):
171+
"""
172+
Returns rndstr always as a binary type
173+
"""
174+
x = rndstr(size, alphabet)
175+
if isinstance(x, six.string_types):
176+
return x.encode('utf-8')
177+
return x
178+
170179

171180
def sid():
172181
"""creates an unique SID for each session.

0 commit comments

Comments
 (0)