Skip to content
This repository was archived by the owner on May 26, 2020. It is now read-only.

Commit 5ec429a

Browse files
committed
Merge pull request #23 from doordash/refresh-token
Add refresh token feature
2 parents e329e6a + 27ca012 commit 5ec429a

File tree

7 files changed

+263
-8
lines changed

7 files changed

+263
-8
lines changed

README.md

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,22 @@ Now in order to access protected api urls you must include the `Authorization: J
6161
$ curl -H "Authorization: JWT <your_token>" http://localhost:8000/protected-url/
6262
```
6363

64+
## Refresh Token
65+
If `JWT_ALLOW_REFRESH` is True, issued tokens can be "refreshed" to obtain a new brand token with renewed expiration time. Add a URL pattern like this:
66+
```python
67+
url(r'^api-token-refresh/', 'rest_framework_jwt.views.refresh_jwt_token'),
68+
```
69+
70+
Pass in an existing token to the refresh endpoint as follows: `{"token": EXISTING_TOKEN}`. Note that only non-expired tokens will work. The JSON response looks the same as the normal obtain token endpoint `{"token": NEW_TOKEN}`.
71+
72+
```bash
73+
$ curl -X POST -H "Content-Type: application/json" -d '{"token":"<EXISTING_TOKEN>}' http://localhost:8000/api-token-refresh/
74+
```
75+
76+
Refresh with tokens can be repeated (token1 -> token2 -> token3), but this chain of token stores the time that the original token (obtained with username/password credentials), as `orig_iat`. You can only keep refreshing tokens up to `JWT_REFRESH_EXPIRATION_DELTA`.
77+
78+
A typical use case might be a web app where you'd like to keep the user "logged in" the site without having to re-enter their password, or get kicked out by surprise before their token expired. Imagine they had a 1-hour token and are just at the last minute while they're still doing something. With mobile you could perhaps store the username/password to get a new token, but this is not a great idea in a browser. Each time the user loads the page, you can check if there is an existing non-expired token and if it's close to being expired, refresh it to extend their session. In other words, if a user is actively using your site, they can keep their "session" alive.
79+
6480
## Additional Settings
6581
There are some additional settings that you can override similar to how you'd do it with Django REST framework itself. Here are all the available defaults.
6682

@@ -75,12 +91,18 @@ JWT_AUTH = {
7591
'JWT_PAYLOAD_HANDLER':
7692
'rest_framework_jwt.utils.jwt_payload_handler',
7793

94+
'JWT_PAYLOAD_GET_USER_ID_HANDLER':
95+
'rest_framework_jwt.utils.jwt_get_user_id_from_payload_handler',
96+
7897
'JWT_SECRET_KEY': settings.SECRET_KEY,
7998
'JWT_ALGORITHM': 'HS256',
8099
'JWT_VERIFY': True,
81100
'JWT_VERIFY_EXPIRATION': True,
82101
'JWT_LEEWAY': 0,
83-
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=300)
102+
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=300),
103+
104+
'JWT_ALLOW_REFRESH': False,
105+
'JWT_REFRESH_EXPIRATION_DELTA': datetime.timedelta(days=7),
84106
}
85107
```
86108
This packages uses the JSON Web Token Python implementation, [PyJWT](https://github.com/progrium/pyjwt) and allows to modify some of it's available options.
@@ -126,8 +148,21 @@ Default is `True`.
126148
127149
Default is `0` seconds.
128150

129-
130151
### JWT_EXPIRATION_DELTA
131152
This is an instance of Python's `datetime.timedelta`. This will be added to `datetime.utcnow()` to set the expiration time.
132153

133154
Default is `datetime.timedelta(seconds=300)`(5 minutes).
155+
156+
### JWT_ALLOW_REFRESH
157+
Enable token refresh functionality. Token issued from `rest_framework_jwt.views.obtain_jwt_token` will have an `orig_iat` field. Default is `False`
158+
159+
### JWT_REFRESH_EXPIRATION_DELTA
160+
Limit on token refresh, is a `datetime.timedelta` instance. This is how much time after the original token that future tokens can be refreshed from.
161+
162+
Default is `datetime.timedelta(days=7)` (7 days).
163+
164+
### JWT_PAYLOAD_HANDLER
165+
Specify a custom function to generate the token payload
166+
167+
### JWT_PAYLOAD_GET_USER_ID_HANDLER
168+
If you store `user_id` differently than the default payload handler does, implement this function to fetch `user_id` from the payload.

rest_framework_jwt/authentication.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414

1515
jwt_decode_handler = api_settings.JWT_DECODE_HANDLER
16+
jwt_get_user_id_from_payload = api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER
1617

1718

1819
class JSONWebTokenAuthentication(BaseAuthentication):
@@ -62,7 +63,7 @@ def authenticate_credentials(self, payload):
6263
Returns an active user that matches the payload's user id and email.
6364
"""
6465
try:
65-
user_id = payload.get('user_id')
66+
user_id = jwt_get_user_id_from_payload(payload)
6667

6768
if user_id:
6869
user = User.objects.get(pk=user_id, is_active=True)

rest_framework_jwt/serializers.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from calendar import timegm
2+
from datetime import datetime, timedelta
3+
import jwt
4+
15
from django.contrib.auth import authenticate, get_user_model
26
from rest_framework import serializers
37

@@ -6,6 +10,8 @@
610

711
jwt_payload_handler = api_settings.JWT_PAYLOAD_HANDLER
812
jwt_encode_handler = api_settings.JWT_ENCODE_HANDLER
13+
jwt_decode_handler = api_settings.JWT_DECODE_HANDLER
14+
jwt_get_user_id_from_payload = api_settings.JWT_PAYLOAD_GET_USER_ID_HANDLER
915

1016

1117
class JSONWebTokenSerializer(serializers.Serializer):
@@ -41,6 +47,13 @@ def validate(self, attrs):
4147

4248
payload = jwt_payload_handler(user)
4349

50+
# Include original issued at time for a brand new token,
51+
# to allow token refresh
52+
if api_settings.JWT_ALLOW_REFRESH:
53+
payload['orig_iat'] = timegm(
54+
datetime.utcnow().utctimetuple()
55+
)
56+
4457
return {
4558
'token': jwt_encode_handler(payload)
4659
}
@@ -50,3 +63,61 @@ def validate(self, attrs):
5063
else:
5164
msg = 'Must include "username" and "password"'
5265
raise serializers.ValidationError(msg)
66+
67+
68+
class RefreshJSONWebTokenSerializer(serializers.Serializer):
69+
"""
70+
Check an access token
71+
"""
72+
token = serializers.CharField()
73+
74+
def validate(self, attrs):
75+
token = attrs['token']
76+
77+
# Check payload valid (based off of JSONWebTokenAuthentication,
78+
# may want to refactor)
79+
try:
80+
payload = jwt_decode_handler(token)
81+
except jwt.ExpiredSignature:
82+
msg = 'Signature has expired.'
83+
raise serializers.ValidationError(msg)
84+
except jwt.DecodeError:
85+
msg = 'Error decoding signature.'
86+
raise serializers.ValidationError(msg)
87+
88+
# Make sure user exists (may want to refactor this)
89+
User = get_user_model()
90+
try:
91+
user_id = jwt_get_user_id_from_payload(payload)
92+
if user_id:
93+
user = User.objects.get(pk=user_id, is_active=True)
94+
else:
95+
msg = 'Invalid payload'
96+
raise serializers.ValidationError(msg)
97+
except User.DoesNotExist:
98+
raise serializers.ValidationError("User doesn't exist")
99+
100+
# Get and check 'orig_iat'
101+
orig_iat = payload.get('orig_iat')
102+
if orig_iat:
103+
# Verify expiration
104+
refresh_limit = api_settings.JWT_REFRESH_EXPIRATION_DELTA
105+
if isinstance(refresh_limit, timedelta):
106+
refresh_limit = (refresh_limit.days * 24 * 3600 +
107+
refresh_limit.seconds)
108+
expiration_timestamp = (
109+
orig_iat +
110+
int(refresh_limit)
111+
)
112+
now_timestamp = timegm(datetime.utcnow().utctimetuple())
113+
if now_timestamp > expiration_timestamp:
114+
raise serializers.ValidationError("Refresh has expired")
115+
else:
116+
raise serializers.ValidationError("orig_iat field is required")
117+
118+
new_payload = jwt_payload_handler(user)
119+
new_payload['orig_iat'] = orig_iat
120+
121+
return {
122+
'token': jwt_encode_handler(new_payload)
123+
}

rest_framework_jwt/settings.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,26 @@
1616
'JWT_PAYLOAD_HANDLER':
1717
'rest_framework_jwt.utils.jwt_payload_handler',
1818

19+
'JWT_PAYLOAD_GET_USER_ID_HANDLER':
20+
'rest_framework_jwt.utils.jwt_get_user_id_from_payload_handler',
21+
1922
'JWT_SECRET_KEY': settings.SECRET_KEY,
2023
'JWT_ALGORITHM': 'HS256',
2124
'JWT_VERIFY': True,
2225
'JWT_VERIFY_EXPIRATION': True,
2326
'JWT_LEEWAY': 0,
24-
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=300)
27+
'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=300),
28+
29+
'JWT_ALLOW_REFRESH': False,
30+
'JWT_REFRESH_EXPIRATION_DELTA': datetime.timedelta(days=7),
2531
}
2632

2733
# List of settings that may be in string import notation.
2834
IMPORT_STRINGS = (
2935
'JWT_ENCODE_HANDLER',
3036
'JWT_DECODE_HANDLER',
3137
'JWT_PAYLOAD_HANDLER',
38+
'JWT_PAYLOAD_GET_USER_ID_HANDLER',
3239
)
3340

3441
api_settings = APISettings(USER_SETTINGS, DEFAULTS, IMPORT_STRINGS)

rest_framework_jwt/tests/test_views.py

Lines changed: 111 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,29 @@
1+
from calendar import timegm
2+
from datetime import datetime, timedelta
3+
import time
4+
15
from django.test import TestCase
26
from django.conf import settings
37
from django.contrib.auth.models import User
8+
49
from rest_framework import status
510
from rest_framework.compat import patterns
611
from rest_framework.test import APIClient
712

813
from rest_framework_jwt import utils
914
from rest_framework_jwt.runtests.models import CustomUser
15+
from rest_framework_jwt.settings import api_settings, DEFAULTS
1016

1117
urlpatterns = patterns(
1218
'',
1319
(r'^auth-token/$', 'rest_framework_jwt.views.obtain_jwt_token'),
20+
(r'^auth-token-refresh/$', 'rest_framework_jwt.views.refresh_jwt_token'),
1421
)
1522

23+
orig_datetime = datetime
24+
1625

17-
class ObtainJSONWebTokenTests(TestCase):
26+
class BaseTestCase(TestCase):
1827
urls = 'rest_framework_jwt.tests.test_views'
1928

2029
def setUp(self):
@@ -29,6 +38,9 @@ def setUp(self):
2938
'password': self.password
3039
}
3140

41+
42+
class ObtainJSONWebTokenTests(BaseTestCase):
43+
3244
def test_jwt_login_json(self):
3345
"""
3446
Ensure JWT login view using JSON POST works.
@@ -145,3 +157,101 @@ def test_jwt_login_json_bad_creds(self):
145157

146158
def tearDown(self):
147159
settings.AUTH_USER_MODEL = self.ORIG_AUTH_USER_MODEL
160+
161+
162+
class RefreshJSONWebTokenTests(BaseTestCase):
163+
urls = 'rest_framework_jwt.tests.test_views'
164+
165+
def setUp(self):
166+
super(RefreshJSONWebTokenTests, self).setUp()
167+
api_settings.JWT_ALLOW_REFRESH = True
168+
169+
def get_token(self):
170+
client = APIClient(enforce_csrf_checks=True)
171+
response = client.post('/auth-token/', self.data, format='json')
172+
return response.data['token']
173+
174+
def create_token(self, user, exp=None, orig_iat=None):
175+
payload = utils.jwt_payload_handler(self.user)
176+
if exp:
177+
payload['exp'] = exp
178+
179+
if orig_iat:
180+
payload['orig_iat'] = timegm(orig_iat.utctimetuple())
181+
182+
token = utils.jwt_encode_handler(payload)
183+
return token
184+
185+
def test_refresh_jwt(self):
186+
"""
187+
Test getting a refreshed token from original token works
188+
"""
189+
client = APIClient(enforce_csrf_checks=True)
190+
191+
orig_token = self.get_token()
192+
orig_token_decoded = utils.jwt_decode_handler(orig_token)
193+
194+
expected_orig_iat = timegm(datetime.utcnow().utctimetuple())
195+
196+
# Make sure 'orig_iat' exists and is the current time (give some slack)
197+
orig_iat = orig_token_decoded['orig_iat']
198+
self.assertLessEqual(orig_iat - expected_orig_iat, 1)
199+
200+
# wait a few seconds, so new token will have different exp
201+
time.sleep(2)
202+
203+
# Now try to get a refreshed token
204+
response = client.post('/auth-token-refresh/', {'token': orig_token},
205+
format='json')
206+
self.assertEqual(response.status_code, status.HTTP_200_OK)
207+
208+
new_token = response.data['token']
209+
new_token_decoded = utils.jwt_decode_handler(new_token)
210+
211+
# Make sure 'orig_iat' on the new token is same as original
212+
self.assertEquals(new_token_decoded['orig_iat'], orig_iat)
213+
self.assertGreater(new_token_decoded['exp'], orig_token_decoded['exp'])
214+
215+
def test_refresh_jwt_fails_with_expired_token(self):
216+
"""
217+
Test that using an expired token to refresh won't work
218+
"""
219+
client = APIClient(enforce_csrf_checks=True)
220+
221+
# Make an expired token..
222+
token = self.create_token(
223+
self.user,
224+
exp=datetime.utcnow() - timedelta(seconds=5),
225+
orig_iat=datetime.utcnow() - timedelta(hours=1)
226+
)
227+
228+
response = client.post('/auth-token-refresh/', {'token': token},
229+
format='json')
230+
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
231+
self.assertRegexpMatches(response.data['non_field_errors'][0],
232+
'Signature has expired')
233+
234+
def test_refresh_jwt_after_refresh_expiration(self):
235+
"""
236+
Test that token can't be refreshed after token refresh limit
237+
"""
238+
client = APIClient(enforce_csrf_checks=True)
239+
240+
orig_iat = (datetime.utcnow() - api_settings.JWT_REFRESH_EXPIRATION_DELTA -
241+
timedelta(seconds=5))
242+
token = self.create_token(
243+
self.user,
244+
exp=datetime.utcnow() + timedelta(hours=1),
245+
orig_iat=orig_iat
246+
)
247+
248+
response = client.post('/auth-token-refresh/', {'token': token},
249+
format='json')
250+
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
251+
self.assertEqual(response.data['non_field_errors'][0],
252+
'Refresh has expired')
253+
254+
def tearDown(self):
255+
# Restore original settings
256+
api_settings.JWT_ALLOW_REFRESH = \
257+
DEFAULTS['JWT_ALLOW_REFRESH']

rest_framework_jwt/utils.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import datetime
1+
from datetime import datetime
22
import jwt
33

44
from rest_framework_jwt.settings import api_settings
@@ -9,10 +9,18 @@ def jwt_payload_handler(user):
99
'user_id': user.pk,
1010
'email': user.email,
1111
'username': user.get_username(),
12-
'exp': datetime.datetime.utcnow() + api_settings.JWT_EXPIRATION_DELTA
12+
'exp': datetime.utcnow() + api_settings.JWT_EXPIRATION_DELTA
1313
}
1414

1515

16+
def jwt_get_user_id_from_payload_handler(payload):
17+
"""
18+
Override this function if user_id is formatted differently in payload
19+
"""
20+
user_id = payload.get('user_id')
21+
return user_id
22+
23+
1624
def jwt_encode_handler(payload):
1725
return jwt.encode(
1826
payload,

0 commit comments

Comments
 (0)