|
4 | 4 | import datetime |
5 | 5 | import base64 |
6 | 6 | import os |
7 | | -import sys |
8 | 7 | import copy |
9 | 8 |
|
10 | 9 | from botocore.vendored import six |
|
16 | 15 |
|
17 | 16 | TOKEN_SKEW = 3 |
18 | 17 | TIME_FORMAT = "%Y%m%dT%H%M%SZ" |
19 | | -PY2 = sys.version[0] == '2' |
| 18 | + |
| 19 | + |
| 20 | +def ensure_text(str_or_bytes, encoding='utf-8'): |
| 21 | + """Ensures an input is a string, decoding if it is bytes. |
| 22 | + """ |
| 23 | + if not isinstance(str_or_bytes, six.text_type): |
| 24 | + return str_or_bytes.decode(encoding) |
| 25 | + return str_or_bytes |
| 26 | + |
| 27 | + |
| 28 | +def ensure_bytes(str_or_bytes, encoding='utf-8', errors='strict'): |
| 29 | + """Ensures an input is bytes, encoding if it is a string. |
| 30 | + """ |
| 31 | + if isinstance(str_or_bytes, six.text_type): |
| 32 | + return str_or_bytes.encode(encoding, errors) |
| 33 | + return str_or_bytes |
20 | 34 |
|
21 | 35 |
|
22 | 36 | class KMSTokenValidator(object): |
@@ -205,12 +219,8 @@ def decrypt_token(self, username, token): |
205 | 219 | version < self.minimum_token_version): |
206 | 220 | raise TokenValidationError('Unacceptable token version.') |
207 | 221 | try: |
208 | | - if PY2: |
209 | | - token_bytes = bytes(token) |
210 | | - else: |
211 | | - token_bytes = bytes(token, 'utf8') |
212 | 222 | token_key = '{0}{1}{2}{3}'.format( |
213 | | - hashlib.sha256(token_bytes).hexdigest(), |
| 223 | + hashlib.sha256(ensure_bytes(token)).hexdigest(), |
214 | 224 | _from, |
215 | 225 | self.to_auth_context, |
216 | 226 | user_type |
@@ -418,7 +428,7 @@ def _cache_token(self, token, not_after): |
418 | 428 | os.makedirs(cachedir) |
419 | 429 | with open(self.token_cache_file, 'w') as f: |
420 | 430 | json.dump({ |
421 | | - 'token': token, |
| 431 | + 'token': ensure_text(token), |
422 | 432 | 'not_after': not_after, |
423 | 433 | 'auth_context': self.auth_context |
424 | 434 | }, f) |
@@ -470,11 +480,7 @@ def get_token(self): |
470 | 480 | Plaintext=payload, |
471 | 481 | EncryptionContext=self.auth_context |
472 | 482 | )['CiphertextBlob'] |
473 | | - if PY2: |
474 | | - token_bytes = bytes(token) |
475 | | - else: |
476 | | - token_bytes = bytes(token, 'utf8') |
477 | | - token = base64.b64encode(token_bytes) |
| 483 | + token = base64.b64encode(ensure_bytes(token)) |
478 | 484 | except (ConnectionError, EndpointConnectionError) as e: |
479 | 485 | logging.exception('Failure connecting to AWS: {}'.format(str(e))) |
480 | 486 | raise ServiceConnectionError() |
|
0 commit comments