Skip to content

Commit 7b1e415

Browse files
authored
Support for passing RSA private key to JWTAuth (#222)
This allows a `JWTAuth` object to be initialized, without requiring that the key be available to the process on disk. Fixes #177.
1 parent 5df3f39 commit 7b1e415

File tree

5 files changed

+181
-23
lines changed

5 files changed

+181
-23
lines changed

HISTORY.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ Release History
7171
the ``auto_session_renewal`` functionality of ``BoxSession``, this means
7272
that authentication for ``JWTAuth`` objects can be done completely
7373
automatically, at the time of first API call.
74+
- The constructor now supports passing the RSA private key in two different
75+
ways: by file system path (existing functionality), or by passing the key
76+
data directly (new functionality). The ``rsa_private_key_file_sys_path``
77+
parameter is now optional, but it is required to pass exactly one of
78+
``rsa_private_key_file_sys_path`` or ``rsa_private_key_data``.
7479
- Document that the ``enterprise_id`` argument to ``JWTAuth`` is allowed to
7580
be ``None``.
7681
- ``authenticate_instance()`` now accepts an ``enterprise`` argument, which

boxsdk/auth/jwt_auth.py

Lines changed: 74 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,13 @@
88

99
from cryptography.hazmat.backends import default_backend
1010
from cryptography.hazmat.primitives import serialization
11+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
1112
import jwt
12-
from six import string_types, text_type
13+
from six import binary_type, string_types, raise_from, text_type
1314

1415
from .oauth2 import OAuth2
1516
from ..object.user import User
16-
from ..util.compat import total_seconds
17+
from ..util.compat import NoneType, total_seconds
1718

1819

1920
class JWTAuth(OAuth2):
@@ -28,7 +29,7 @@ def __init__(
2829
client_secret,
2930
enterprise_id,
3031
jwt_key_id,
31-
rsa_private_key_file_sys_path,
32+
rsa_private_key_file_sys_path=None,
3233
rsa_private_key_passphrase=None,
3334
user=None,
3435
store_tokens=None,
@@ -37,9 +38,13 @@ def __init__(
3738
access_token=None,
3839
network_layer=None,
3940
jwt_algorithm='RS256',
41+
rsa_private_key_data=None,
4042
):
4143
"""Extends baseclass method.
4244
45+
Must pass exactly one of either `rsa_private_key_file_sys_path` or
46+
`rsa_private_key_data`.
47+
4348
If both `enterprise_id` and `user` are non-`None`, the `user` takes
4449
precedence when `refresh()` is called. This can be overruled with a
4550
call to `authenticate_instance()`.
@@ -68,13 +73,13 @@ def __init__(
6873
:type jwt_key_id:
6974
`unicode`
7075
:param rsa_private_key_file_sys_path:
71-
Path to an RSA private key file, used for signing the JWT assertion.
76+
(optional) Path to an RSA private key file, used for signing the JWT assertion.
7277
:type rsa_private_key_file_sys_path:
7378
`unicode`
7479
:param rsa_private_key_passphrase:
7580
Passphrase used to unlock the private key. Do not pass a unicode string - this must be bytes.
7681
:type rsa_private_key_passphrase:
77-
`str` or None
82+
`bytes` or None
7883
:param user:
7984
(optional) The user to authenticate, expressed as a Box User ID or
8085
as a :class:`User` instance.
@@ -120,8 +125,20 @@ def __init__(
120125
Which algorithm to use for signing the JWT assertion. Must be one of 'RS256', 'RS384', 'RS512'.
121126
:type jwt_algorithm:
122127
`unicode`
128+
:param rsa_private_key_data:
129+
(optional) Contents of RSA private key, used for signing the JWT assertion. Do not pass a
130+
unicode string. Can pass a byte string, or a file-like object that returns bytes, or an
131+
already-loaded `RSAPrivateKey` object.
132+
:type rsa_private_key_data: `bytes` or :class:`io.IOBase` or :class:`RSAPrivateKey`
123133
"""
124134
user_id = self._normalize_user_id(user)
135+
rsa_private_key = self._normalize_rsa_private_key(
136+
file_sys_path=rsa_private_key_file_sys_path,
137+
data=rsa_private_key_data,
138+
passphrase=rsa_private_key_passphrase,
139+
)
140+
del rsa_private_key_data
141+
del rsa_private_key_file_sys_path
125142
super(JWTAuth, self).__init__(
126143
client_id,
127144
client_secret,
@@ -132,12 +149,7 @@ def __init__(
132149
refresh_token=None,
133150
network_layer=network_layer,
134151
)
135-
with open(rsa_private_key_file_sys_path, 'rb') as key_file:
136-
self._rsa_private_key = serialization.load_pem_private_key(
137-
key_file.read(),
138-
password=rsa_private_key_passphrase,
139-
backend=default_backend(),
140-
)
152+
self._rsa_private_key = rsa_private_key
141153
self._enterprise_id = enterprise_id
142154
self._jwt_algorithm = jwt_algorithm
143155
self._jwt_key_id = jwt_key_id
@@ -295,3 +307,54 @@ def _refresh(self, access_token):
295307
else:
296308
new_access_token = self.authenticate_user()
297309
return new_access_token, None
310+
311+
@classmethod
312+
def _normalize_rsa_private_key(cls, file_sys_path, data, passphrase=None):
313+
if len(list(filter(None, [file_sys_path, data]))) != 1:
314+
raise TypeError("must pass exactly one of either rsa_private_key_file_sys_path or rsa_private_key_data")
315+
if file_sys_path:
316+
with open(file_sys_path, 'rb') as key_file:
317+
data = key_file.read()
318+
if hasattr(data, 'read') and callable(data.read):
319+
data = data.read()
320+
if isinstance(data, text_type):
321+
try:
322+
data = data.encode('ascii')
323+
except UnicodeError:
324+
raise_from(
325+
TypeError("rsa_private_key_data must contain binary data (bytes/str), not a text/unicode string"),
326+
None,
327+
)
328+
if isinstance(data, binary_type):
329+
passphrase = cls._normalize_rsa_private_key_passphrase(passphrase)
330+
return serialization.load_pem_private_key(
331+
data,
332+
password=passphrase,
333+
backend=default_backend(),
334+
)
335+
if isinstance(data, RSAPrivateKey):
336+
return data
337+
raise TypeError(
338+
'rsa_private_key_data must be binary data (bytes/str), '
339+
'a file-like object with a read() method, '
340+
'or an instance of RSAPrivateKey, '
341+
'but got {0!r}'
342+
.format(data.__class__.__name__)
343+
)
344+
345+
@staticmethod
346+
def _normalize_rsa_private_key_passphrase(passphrase):
347+
if isinstance(passphrase, text_type):
348+
try:
349+
return passphrase.encode('ascii')
350+
except UnicodeError:
351+
raise_from(
352+
TypeError("rsa_private_key_passphrase must contain binary data (bytes/str), not a text/unicode string"),
353+
None,
354+
)
355+
if not isinstance(passphrase, (binary_type, NoneType)):
356+
raise TypeError(
357+
"rsa_private_key_passphrase must contain binary data (bytes/str), got {0!r}"
358+
.format(passphrase.__class__.__name__)
359+
)
360+
return passphrase

boxsdk/util/compat.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
import six
88

99

10+
NoneType = type(None)
11+
12+
1013
if not hasattr(timedelta, 'total_seconds'):
1114
def total_seconds(delta):
1215
"""

test/unit/auth/test_jwt_auth.py

Lines changed: 97 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,18 @@
44

55
from contextlib import contextmanager
66
from datetime import datetime, timedelta
7+
import io
78
from itertools import product
89
import json
910
import random
1011
import string
1112

1213
from cryptography.hazmat.backends import default_backend
14+
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, generate_private_key as generate_rsa_private_key
15+
from cryptography.hazmat.primitives import serialization
1316
from mock import Mock, mock_open, patch, sentinel
1417
import pytest
15-
from six import string_types, text_type
18+
from six import binary_type, string_types, text_type
1619

1720
from boxsdk.auth.jwt_auth import JWTAuth
1821
from boxsdk.config import API
@@ -35,11 +38,26 @@ def jwt_key_id():
3538
return 'jwt_key_id_1'
3639

3740

41+
@pytest.fixture(scope='module')
42+
def rsa_private_key_object():
43+
return generate_rsa_private_key(public_exponent=65537, key_size=4096, backend=default_backend())
44+
45+
3846
@pytest.fixture(params=(None, b'strong_password'))
3947
def rsa_passphrase(request):
4048
return request.param
4149

4250

51+
@pytest.fixture
52+
def rsa_private_key_bytes(rsa_private_key_object, rsa_passphrase):
53+
encryption = serialization.BestAvailableEncryption(rsa_passphrase) if rsa_passphrase else serialization.NoEncryption()
54+
return rsa_private_key_object.private_bytes(
55+
encoding=serialization.Encoding.PEM,
56+
format=serialization.PrivateFormat.PKCS8,
57+
encryption_algorithm=encryption,
58+
)
59+
60+
4361
@pytest.fixture(scope='function')
4462
def successful_token_response(successful_token_mock, successful_token_json_response):
4563
# pylint:disable=redefined-outer-name
@@ -52,8 +70,76 @@ def successful_token_response(successful_token_mock, successful_token_json_respo
5270
return successful_token_mock
5371

5472

73+
@pytest.mark.parametrize(('key_file', 'key_data'), [(None, None), ('fake sys path', 'fake key data')])
74+
@pytest.mark.parametrize('rsa_passphrase', [None])
75+
def test_jwt_auth_init_raises_type_error_unless_exactly_one_of_rsa_private_key_file_or_data_is_given(key_file, key_data, rsa_private_key_bytes):
76+
kwargs = dict(
77+
rsa_private_key_data=rsa_private_key_bytes,
78+
client_id=None,
79+
client_secret=None,
80+
jwt_key_id=None,
81+
enterprise_id=None,
82+
)
83+
JWTAuth(**kwargs)
84+
kwargs.update(rsa_private_key_file_sys_path=key_file, rsa_private_key_data=key_data)
85+
with pytest.raises(TypeError):
86+
JWTAuth(**kwargs)
87+
88+
89+
@pytest.mark.parametrize('key_data', [object(), u'ƒøø'])
90+
@pytest.mark.parametrize('rsa_passphrase', [None])
91+
def test_jwt_auth_init_raises_type_error_if_rsa_private_key_data_has_unexpected_type(key_data, rsa_private_key_bytes):
92+
kwargs = dict(
93+
rsa_private_key_data=rsa_private_key_bytes,
94+
client_id=None,
95+
client_secret=None,
96+
jwt_key_id=None,
97+
enterprise_id=None,
98+
)
99+
JWTAuth(**kwargs)
100+
kwargs.update(rsa_private_key_data=key_data)
101+
with pytest.raises(TypeError):
102+
JWTAuth(**kwargs)
103+
104+
105+
@pytest.mark.parametrize('rsa_private_key_data_type', [io.BytesIO, text_type, binary_type, RSAPrivateKey])
106+
def test_jwt_auth_init_accepts_rsa_private_key_data(rsa_private_key_bytes, rsa_passphrase, rsa_private_key_data_type):
107+
if rsa_private_key_data_type is text_type:
108+
rsa_private_key_data = text_type(rsa_private_key_bytes.decode('ascii'))
109+
elif rsa_private_key_data_type is RSAPrivateKey:
110+
rsa_private_key_data = serialization.load_pem_private_key(
111+
rsa_private_key_bytes,
112+
password=rsa_passphrase,
113+
backend=default_backend(),
114+
)
115+
else:
116+
rsa_private_key_data = rsa_private_key_data_type(rsa_private_key_bytes)
117+
JWTAuth(
118+
rsa_private_key_data=rsa_private_key_data,
119+
rsa_private_key_passphrase=rsa_passphrase,
120+
client_id=None,
121+
client_secret=None,
122+
jwt_key_id=None,
123+
enterprise_id=None,
124+
)
125+
126+
127+
@pytest.fixture(params=[False, True])
128+
def pass_private_key_by_path(request):
129+
"""For jwt_auth_init_mocks, whether to pass the private key via sys_path (True) or pass the data directly (False)."""
130+
return request.param
131+
132+
55133
@pytest.fixture
56-
def jwt_auth_init_mocks(mock_network_layer, successful_token_response, jwt_algorithm, jwt_key_id, rsa_passphrase):
134+
def jwt_auth_init_mocks(
135+
mock_network_layer,
136+
successful_token_response,
137+
jwt_algorithm,
138+
jwt_key_id,
139+
rsa_passphrase,
140+
rsa_private_key_bytes,
141+
pass_private_key_by_path,
142+
):
57143
# pylint:disable=redefined-outer-name
58144

59145
@contextmanager
@@ -70,15 +156,14 @@ def _jwt_auth_init_mocks(**kwargs):
70156
'box_device_id': '0',
71157
'box_device_name': 'my_awesome_device',
72158
}
73-
74159
mock_network_layer.request.return_value = successful_token_response
75-
key_file_read_data = b'key_file_read_data'
76-
with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=key_file_read_data), create=True) as jwt_auth_open:
160+
with patch('boxsdk.auth.jwt_auth.open', mock_open(read_data=rsa_private_key_bytes), create=True) as jwt_auth_open:
77161
with patch('cryptography.hazmat.primitives.serialization.load_pem_private_key') as load_pem_private_key:
78162
oauth = JWTAuth(
79163
client_id=fake_client_id,
80164
client_secret=fake_client_secret,
81-
rsa_private_key_file_sys_path=sentinel.rsa_path,
165+
rsa_private_key_file_sys_path=(sentinel.rsa_path if pass_private_key_by_path else None),
166+
rsa_private_key_data=(None if pass_private_key_by_path else rsa_private_key_bytes),
82167
rsa_private_key_passphrase=rsa_passphrase,
83168
network_layer=mock_network_layer,
84169
box_device_name='my_awesome_device',
@@ -87,11 +172,13 @@ def _jwt_auth_init_mocks(**kwargs):
87172
enterprise_id=kwargs.pop('enterprise_id', None),
88173
**kwargs
89174
)
90-
91-
jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb')
92-
jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member
175+
if pass_private_key_by_path:
176+
jwt_auth_open.assert_called_once_with(sentinel.rsa_path, 'rb')
177+
jwt_auth_open.return_value.read.assert_called_once_with() # pylint:disable=no-member
178+
else:
179+
jwt_auth_open.assert_not_called()
93180
load_pem_private_key.assert_called_once_with(
94-
key_file_read_data,
181+
rsa_private_key_bytes,
95182
password=rsa_passphrase,
96183
backend=default_backend(),
97184
)

test/unit/util/test_api_call_decorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ def func():
6666

6767
def test_api_call_decorated_method_must_be_a_cloneable_method():
6868

69-
class Cls(object):
69+
class NonCloneable(object):
7070
@api_call
7171
def func(self):
7272
pass
7373

74-
obj = Cls()
74+
obj = NonCloneable()
7575
with pytest.raises(TypeError):
7676
obj.func()
7777

0 commit comments

Comments
 (0)