Skip to content
This repository was archived by the owner on May 26, 2020. It is now read-only.

Commit 5fc75bf

Browse files
committed
[refresh-token] add refresh token endpoint
typo [refresh-token] Refactor renew token stuff, add jwt_get_user_id_from_payload_handler for custom user_id format [refresh-token] add tests for refresh token feature some more api checks
1 parent c617edb commit 5fc75bf

File tree

7 files changed

+299
-7
lines changed

7 files changed

+299
-7
lines changed

rest_framework_jwt/authentication.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
jwt_decode_handler = api_settings.JWT_DECODE_HANDLER
16+
jwt_get_user_id_from_payload = api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER
1617

1718

1819
class JSONWebTokenAuthentication(BaseAuthentication):
@@ -62,7 +63,7 @@ def authenticate_credentials(self, payload):
6263
Returns an active user that matches the payload's user id and email.
6364
"""
6465
try:
65-
user_id = payload.get('user_id')
66+
user_id = jwt_get_user_id_from_payload(payload)
6667

6768
if user_id:
6869
user = User.objects.get(pk=user_id, is_active=True)

rest_framework_jwt/serializers.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
1+
from calendar import timegm
2+
from datetime import datetime
3+
import jwt
4+
15
from django.contrib.auth import authenticate, get_user_model
2-
from rest_framework import serializers
6+
from rest_framework import serializers, exceptions
37

48
from rest_framework_jwt.settings import api_settings
59

610

711
jwt_payload_handler = api_settings.JWT_PAYLOAD_HANDLER
812
jwt_encode_handler = api_settings.JWT_ENCODE_HANDLER
13+
jwt_decode_handler = api_settings.JWT_DECODE_HANDLER
14+
jwt_get_user_id_from_payload = api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER
915

1016

1117
class JSONWebTokenSerializer(serializers.Serializer):
@@ -41,6 +47,13 @@ def validate(self, attrs):
4147

4248
payload = jwt_payload_handler(user)
4349

50+
# Include original issued at time for a brand new token,
51+
# to allow token refresh
52+
if api_settings.JWT_ALLOW_TOKEN_RENEWAL:
53+
payload['orig_iat'] = timegm(
54+
datetime.utcnow().utctimetuple()
55+
)
56+
4457
return {
4558
'token': jwt_encode_handler(payload)
4659
}
@@ -50,3 +63,57 @@ def validate(self, attrs):
5063
else:
5164
msg = 'Must include "username" and "password"'
5265
raise serializers.ValidationError(msg)
66+
67+
68+
class RefreshJSONWebTokenSerializer(serializers.Serializer):
69+
"""
70+
Check an access token
71+
"""
72+
token = serializers.CharField()
73+
74+
def validate(self, attrs):
75+
token = attrs['token']
76+
77+
# Check payload valid (based off of JSONWebTokenAuthentication,
78+
# may want to refactor)
79+
try:
80+
payload = jwt_decode_handler(token)
81+
except jwt.ExpiredSignature:
82+
msg = 'Signature has expired.'
83+
raise serializers.ValidationError(msg)
84+
except jwt.DecodeError:
85+
msg = 'Error decoding signature.'
86+
raise serializers.ValidationError(msg)
87+
88+
# Make sure user exists (may want to refactor this)
89+
User = get_user_model()
90+
try:
91+
user_id = jwt_get_user_id_from_payload(payload)
92+
if user_id:
93+
user = User.objects.get(pk=user_id, is_active=True)
94+
else:
95+
msg = 'Invalid payload'
96+
raise serializers.ValidationError(msg)
97+
except User.DoesNotExist:
98+
raise serializers.ValidationError("User doesn't exist")
99+
100+
# Get and check 'orig_iat'
101+
orig_iat = payload.get('orig_iat')
102+
if orig_iat:
103+
# Verify expiration
104+
expiration_timestamp = (
105+
orig_iat +
106+
int(api_settings.JWT_TOKEN_RENEWAL_LIMIT.total_seconds())
107+
)
108+
now_timestamp = timegm(datetime.utcnow().utctimetuple())
109+
if now_timestamp > expiration_timestamp:
110+
raise serializers.ValidationError("Refresh has expired")
111+
else:
112+
raise serializers.ValidationError("orig_iat field is required")
113+
114+
new_payload = jwt_payload_handler(user)
115+
new_payload['orig_iat'] = orig_iat
116+
117+
return {
118+
'token': jwt_encode_handler(new_payload)
119+
}

rest_framework_jwt/settings.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,26 @@
1616
'JWT_PAYLOAD_HANDLER':
1717
'rest_framework_jwt.utils.jwt_payload_handler',
1818

19+
'JWT_PAYLOAD_GET_USER_ID_HANDLER':
20+
'rest_framework_jwt.utils.jwt_get_user_id_from_payload_handler',
21+
1922
'JWT_SECRET_KEY': settings.SECRET_KEY,
2023
'JWT_ALGORITHM': 'HS256',
2124
'JWT_VERIFY': True,
2225
'JWT_VERIFY_EXPIRATION': True,
2326
'JWT_LEEWAY': 0,
24-
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=300)
27+
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=300),
28+
29+
'JWT_ALLOW_TOKEN_RENEWAL': False,
30+
'JWT_TOKEN_RENEWAL_LIMIT': datetime.timedelta(days=7),
2531
}
2632

2733
# List of settings that may be in string import notation.
2834
IMPORT_STRINGS = (
2935
'JWT_ENCODE_HANDLER',
3036
'JWT_DECODE_HANDLER',
3137
'JWT_PAYLOAD_HANDLER',
38+
'JWT_PAYLOAD_GET_USER_ID_HANDLER',
3239
)
3340

3441
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)

rest_framework_jwt/tests/simtime.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
"""
2+
Test utility to simulate changing the current time, namely overriding utcnow()
3+
"""
4+
5+
import datetime
6+
from datetime import datetime as orig_datetime
7+
8+
9+
SIMTIME = None
10+
11+
12+
def set_simtime(simtime):
13+
"""
14+
Argument should be a naive utc datetime
15+
"""
16+
global SIMTIME
17+
SIMTIME = simtime
18+
19+
20+
def clear_simtime():
21+
global SIMTIME
22+
SIMTIME = None
23+
24+
25+
class SimulationDatetimeMeta(type):
26+
"""
27+
Need to override isinstance(<datetime.datetime obj>, SimulationDatetime) to
28+
return True
29+
"""
30+
def __instancecheck__(self, other):
31+
if isinstance(other, datetime.datetime):
32+
return True
33+
34+
35+
class SimulationDatetime(datetime.datetime):
36+
"""
37+
Mock datetime object with patched utcnow() method
38+
"""
39+
40+
@classmethod
41+
def utcnow(cls):
42+
if SIMTIME:
43+
# assert False, SIMTIME
44+
return SIMTIME
45+
46+
return orig_datetime.utcnow()
47+
48+
__metaclass__ = SimulationDatetimeMeta

rest_framework_jwt/tests/test_views.py

Lines changed: 140 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
1+
from calendar import timegm
2+
from datetime import datetime, timedelta
3+
14
from django.test import TestCase
25
from django.conf import settings
36
from django.contrib.auth.models import User
7+
import jwt
48
from rest_framework import status
59
from rest_framework.compat import patterns
610
from rest_framework.test import APIClient
711

12+
import rest_framework_jwt.serializers
813
from rest_framework_jwt import utils
914
from rest_framework_jwt.runtests.models import CustomUser
15+
from rest_framework_jwt.settings import api_settings, DEFAULTS
16+
from rest_framework_jwt.tests import simtime
17+
from rest_framework_jwt.tests.simtime import SimulationDatetime
1018

1119
urlpatterns = patterns(
1220
'',
1321
(r'^auth-token/$', 'rest_framework_jwt.views.obtain_jwt_token'),
22+
(r'^auth-token-refresh/$', 'rest_framework_jwt.views.refresh_jwt_token'),
1423
)
1524

25+
orig_datetime = datetime
26+
1627

17-
class ObtainJSONWebTokenTests(TestCase):
28+
class BaseTestCase(TestCase):
1829
urls = 'rest_framework_jwt.tests.test_views'
1930

2031
def setUp(self):
@@ -29,6 +40,9 @@ def setUp(self):
2940
'password': self.password
3041
}
3142

43+
44+
class ObtainJSONWebTokenTests(BaseTestCase):
45+
3246
def test_jwt_login_json(self):
3347
"""
3448
Ensure JWT login view using JSON POST works.
@@ -145,3 +159,128 @@ def test_jwt_login_json_bad_creds(self):
145159

146160
def tearDown(self):
147161
settings.AUTH_USER_MODEL = self.ORIG_AUTH_USER_MODEL
162+
163+
164+
class RefreshJSONWebTokenTests(BaseTestCase):
165+
urls = 'rest_framework_jwt.tests.test_views'
166+
167+
def setUp(self):
168+
super(RefreshJSONWebTokenTests, self).setUp()
169+
api_settings.JWT_ALLOW_TOKEN_RENEWAL = True
170+
171+
# monkey patch datetime objects in places that use datetime.utcnow()
172+
jwt.datetime = SimulationDatetime
173+
rest_framework_jwt.serializers.datetime = SimulationDatetime
174+
utils.datetime = SimulationDatetime
175+
176+
def get_token(self):
177+
client = APIClient(enforce_csrf_checks=True)
178+
response = client.post('/auth-token/', self.data, format='json')
179+
return response.data['token']
180+
181+
def test_refresh_jwt(self):
182+
"""
183+
Test getting a refreshed token from original token works
184+
"""
185+
client = APIClient(enforce_csrf_checks=True)
186+
187+
# Set simulation time to now
188+
currtime = datetime.utcnow()
189+
simtime.set_simtime(currtime)
190+
191+
orig_token = self.get_token()
192+
orig_token_decoded = utils.jwt_decode_handler(orig_token)
193+
194+
# Make sure 'orig_iat' exists and is the current time
195+
orig_iat = orig_token_decoded['orig_iat']
196+
self.assertEquals(orig_iat, timegm(currtime.utctimetuple()))
197+
198+
# Fast-forward to later time (but before first token expires)
199+
currtime += api_settings.JWT_EXPIRATION_DELTA - timedelta(seconds=30)
200+
simtime.set_simtime(currtime)
201+
202+
# Now try to get a refreshed token
203+
response = client.post('/auth-token-refresh/', {'token': orig_token},
204+
format='json')
205+
self.assertEqual(response.status_code, status.HTTP_200_OK)
206+
207+
new_token = response.data['token']
208+
new_token_decoded = utils.jwt_decode_handler(new_token)
209+
210+
# Make sure 'orig_iat' on the new token is same as origina
211+
self.assertEquals(new_token_decoded['orig_iat'], orig_iat)
212+
self.assertGreater(new_token_decoded['exp'], orig_token_decoded['exp'])
213+
214+
def test_refresh_jwt_fails_with_expired_token(self):
215+
"""
216+
Test that using an expired token to refresh won't work
217+
"""
218+
client = APIClient(enforce_csrf_checks=True)
219+
token = self.get_token()
220+
221+
# Fast-forward to after token expires
222+
after_expire = (
223+
datetime.utcnow() + api_settings.JWT_EXPIRATION_DELTA +
224+
timedelta(seconds=10)
225+
)
226+
simtime.set_simtime(after_expire)
227+
228+
response = client.post('/auth-token-refresh/', {'token': token},
229+
format='json')
230+
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
231+
self.assertRegexpMatches(response.data['non_field_errors'][0],
232+
'Signature has expired')
233+
234+
def test_refresh_jwt_after_renewal_expiration(self):
235+
"""
236+
Test that token can't be refreshed after token renewal limit
237+
"""
238+
# For simpler test, make the RENEWAL_LIMIT just a bit larger than
239+
# EXPIRATION_DELTA
240+
api_settings.JWT_TOKEN_RENEWAL_LIMIT = (
241+
api_settings.JWT_EXPIRATION_DELTA +
242+
api_settings.JWT_EXPIRATION_DELTA / 2
243+
)
244+
245+
client = APIClient(enforce_csrf_checks=True)
246+
247+
initial_time = datetime.utcnow()
248+
249+
token1 = self.get_token()
250+
251+
# Token1 refresh to Token2, just before it expires
252+
currtime = (initial_time + api_settings.JWT_EXPIRATION_DELTA -
253+
timedelta(seconds=5))
254+
255+
simtime.set_simtime(currtime)
256+
257+
response = client.post('/auth-token-refresh/', {'token': token1},
258+
format='json')
259+
token2 = response.data['token']
260+
261+
# Fast-forward to after token renewal expiration
262+
# Token2 hasn't expired yet, but it can't be used to renew anymore!
263+
currtime = (initial_time + api_settings.JWT_TOKEN_RENEWAL_LIMIT +
264+
timedelta(minutes=1))
265+
simtime.set_simtime(currtime)
266+
267+
response = client.post('/auth-token-refresh/', {'token': token2},
268+
format='json')
269+
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
270+
self.assertEqual(response.data['non_field_errors'][0],
271+
'Refresh has expired')
272+
273+
def tearDown(self):
274+
# Restore original settings
275+
api_settings.JWT_ALLOW_TOKEN_RENEWAL = \
276+
DEFAULTS['JWT_ALLOW_TOKEN_RENEWAL']
277+
278+
api_settings.JWT_TOKEN_RENEWAL_LIMIT = \
279+
DEFAULTS['JWT_TOKEN_RENEWAL_LIMIT']
280+
281+
# Undo datetime monkeypatching
282+
jwt.datetime = orig_datetime
283+
rest_framework_jwt.serializers.datetime = orig_datetime
284+
utils.datetime = orig_datetime
285+
286+
simtime.clear_simtime()

rest_framework_jwt/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import datetime
1+
from datetime import datetime
22
import jwt
33

44
from rest_framework_jwt.settings import api_settings
@@ -9,10 +9,18 @@ def jwt_payload_handler(user):
99
'user_id': user.pk,
1010
'email': user.email,
1111
'username': user.get_username(),
12-
'exp': datetime.datetime.utcnow() + api_settings.JWT_EXPIRATION_DELTA
12+
'exp': datetime.utcnow() + api_settings.JWT_EXPIRATION_DELTA
1313
}
1414

1515

16+
def jwt_get_user_id_from_payload_handler(payload):
17+
"""
18+
Override this function if user_id is formatted differently in payload
19+
"""
20+
user_id = payload.get('user_id')
21+
return user_id
22+
23+
1624
def jwt_encode_handler(payload):
1725
return jwt.encode(
1826
payload,

0 commit comments

Comments
 (0)