Skip to content

Commit e847c46

Browse files
garrettheelryan-lane
authored andcommitted
Fix a bug with bytes in py3 (#13)
* Fix a bug with bytes in py3 * nits * unused constant * Fix lint
1 parent 7e60812 commit e847c46

File tree

3 files changed

+22
-16
lines changed

3 files changed

+22
-16
lines changed

kmsauth/__init__.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import datetime
55
import base64
66
import os
7-
import sys
87
import copy
98

109
from botocore.vendored import six
@@ -16,7 +15,22 @@
1615

1716
TOKEN_SKEW = 3
1817
TIME_FORMAT = "%Y%m%dT%H%M%SZ"
19-
PY2 = sys.version[0] == '2'
18+
19+
20+
def ensure_text(str_or_bytes, encoding='utf-8'):
21+
"""Ensures an input is a string, decoding if it is bytes.
22+
"""
23+
if not isinstance(str_or_bytes, six.text_type):
24+
return str_or_bytes.decode(encoding)
25+
return str_or_bytes
26+
27+
28+
def ensure_bytes(str_or_bytes, encoding='utf-8', errors='strict'):
29+
"""Ensures an input is bytes, encoding if it is a string.
30+
"""
31+
if isinstance(str_or_bytes, six.text_type):
32+
return str_or_bytes.encode(encoding, errors)
33+
return str_or_bytes
2034

2135

2236
class KMSTokenValidator(object):
@@ -205,12 +219,8 @@ def decrypt_token(self, username, token):
205219
version < self.minimum_token_version):
206220
raise TokenValidationError('Unacceptable token version.')
207221
try:
208-
if PY2:
209-
token_bytes = bytes(token)
210-
else:
211-
token_bytes = bytes(token, 'utf8')
212222
token_key = '{0}{1}{2}{3}'.format(
213-
hashlib.sha256(token_bytes).hexdigest(),
223+
hashlib.sha256(ensure_bytes(token)).hexdigest(),
214224
_from,
215225
self.to_auth_context,
216226
user_type
@@ -418,7 +428,7 @@ def _cache_token(self, token, not_after):
418428
os.makedirs(cachedir)
419429
with open(self.token_cache_file, 'w') as f:
420430
json.dump({
421-
'token': token,
431+
'token': ensure_text(token),
422432
'not_after': not_after,
423433
'auth_context': self.auth_context
424434
}, f)
@@ -470,11 +480,7 @@ def get_token(self):
470480
Plaintext=payload,
471481
EncryptionContext=self.auth_context
472482
)['CiphertextBlob']
473-
if PY2:
474-
token_bytes = bytes(token)
475-
else:
476-
token_bytes = bytes(token, 'utf8')
477-
token = base64.b64encode(token_bytes)
483+
token = base64.b64encode(ensure_bytes(token))
478484
except (ConnectionError, EndpointConnectionError) as e:
479485
logging.exception('Failure connecting to AWS: {}'.format(str(e)))
480486
raise ServiceConnectionError()

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ flake8==2.3.0
1212
# Measures code coverage and emits coverage reports
1313
# Licence: BSD
1414
# Upstream url: https://pypi.python.org/pypi/coverage
15-
coverage==3.7.1
15+
coverage==4.4.2
1616

1717
# tool to check your Python code against some of the style conventions
1818
# License: Expat License
@@ -22,7 +22,7 @@ pep8==1.5.7
2222
# nose makes testing easier
2323
# License: GNU Library or Lesser General Public License (LGPL)
2424
# Upstream url: http://readthedocs.org/docs/nose
25-
nose==1.3.3
25+
nose==1.3.7
2626

2727
# Mocking and Patching Library for Testing
2828
# License: BSD

tests/unit/kmsauth/kmsauth_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def test_get_username(self):
430430
def test_get_token(self, boto_mock):
431431
kms_mock = MagicMock()
432432
kms_mock.encrypt = MagicMock(
433-
return_value={'CiphertextBlob': 'encrypted'}
433+
return_value={'CiphertextBlob': b'encrypted'}
434434
)
435435
boto_mock.return_value = kms_mock
436436
client = kmsauth.KMSTokenGenerator(

0 commit comments

Comments
 (0)