Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 32 additions & 46 deletions promo_code/user/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ def create(self, validated_data):
other=validated_data['other'],
password=validated_data['password'],
)
user.token_version += 1
user.save()
return user
except django.core.exceptions.ValidationError as e:
raise rest_framework.serializers.ValidationError(e.messages)
Expand All @@ -80,13 +82,26 @@ class SignInSerializer(
write_only=True,
)

def validate(self, data):
email = data.get('email')
password = data.get('password')
def validate(self, attrs):
user = self.authenticate_user(attrs)

self.update_token_version(user)

data = super().validate(attrs)

refresh = rest_framework_simplejwt.tokens.RefreshToken(data['refresh'])

self.invalidate_previous_tokens(user, refresh['jti'])

return data

def authenticate_user(self, attrs):
email = attrs.get('email')
password = attrs.get('password')

if not email or not password:
raise rest_framework.serializers.ValidationError(
{'status': 'error', 'message': 'Both fields are required.'},
raise rest_framework.exceptions.ValidationError(
{'detail': 'Both email and password are required'},
code='required',
)

Expand All @@ -95,55 +110,26 @@ def validate(self, data):
email=email,
password=password,
)
if not user:
raise rest_framework.exceptions.AuthenticationFailed(
{'status': 'error', 'message': 'Invalid email or password.'},
code='authorization',
)

authenticate_kwargs = {
self.username_field: data[self.username_field],
'password': data['password'],
}
try:
authenticate_kwargs['request'] = self.context['request']
except KeyError:
pass

self.user = django.contrib.auth.authenticate(**authenticate_kwargs)

if not getattr(self.user, 'is_active', None):
if not user or not user.is_active:
raise rest_framework.exceptions.AuthenticationFailed(
self.error_messages['no_active_account'],
'no_active_account',
{'detail': 'Invalid credentials or inactive account'},
code='authentication_failed',
)

self.user.token_version += 1
self.user.save()
return user

refresh = self.get_token(self.user)
data = {
'refresh': str(refresh),
'access': str(refresh.access_token),
}

current_jti = refresh['jti']

tokens_qs = tb_models.OutstandingToken.objects.filter(
user=self.user,
)

outstanding_tokens = tokens_qs.exclude(jti=current_jti)
def invalidate_previous_tokens(self, user, current_jti):
outstanding_tokens = tb_models.OutstandingToken.objects.filter(
user=user,
).exclude(jti=current_jti)

for token in outstanding_tokens:
(
tb_models.BlacklistedToken.objects.get_or_create(
token=token,
)
)
tb_models.BlacklistedToken.objects.get_or_create(token=token)

data['token_version'] = self.user.token_version
return data
def update_token_version(self, user):
user.token_version += 1
user.save()

def get_token(self, user):
token = super().get_token(user)
Expand Down
65 changes: 44 additions & 21 deletions promo_code/user/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def test_valid_registration(self):
response.status_code,
rest_framework.status.HTTP_200_OK,
)
self.assertIn('token', response.data)
self.assertIn('access', response.data)
self.assertTrue(
user.models.User.objects.filter(
email='[email protected]',
Expand Down Expand Up @@ -391,7 +391,7 @@ def test_signin_success(self):

class JWTTests(rest_framework.test.APITestCase):
def setUp(self):

self.signup_url = django.urls.reverse('api-user:sign-up')
self.signin_url = django.urls.reverse('api-user:sign-in')
self.protected_url = django.urls.reverse('api-core:protected')
self.refresh_url = django.urls.reverse('api-user:token_refresh')
Expand Down Expand Up @@ -428,13 +428,54 @@ def test_access_protected_view_with_valid_token(self):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data['status'], 'request was permitted')

def test_refresh_token_invalidation_after_new_login(self):
def test_registration_token_invalid_after_login(self):
data = {
'email': '[email protected]',
'password': 'StrongPass123!cd',
'name': 'John',
'surname': 'Doe',
'other': {'age': 22, 'country': 'us'},
}
response = self.client.post(
self.signup_url,
data,
format='json',
)
reg_access_token = response.data['access']

self.client.credentials(
HTTP_AUTHORIZATION=f'Bearer {reg_access_token}',
)
response = self.client.get(self.protected_url)
self.assertEqual(response.status_code, 200)

login_data = {'email': data['email'], 'password': data['password']}
response = self.client.post(
self.signin_url,
login_data,
format='json',
)
login_access_token = response.data['access']

self.client.credentials(
HTTP_AUTHORIZATION=f'Bearer {reg_access_token}',
)
response = self.client.get(self.protected_url)
self.assertEqual(response.status_code, 401)

self.client.credentials(
HTTP_AUTHORIZATION=f'Bearer {login_access_token}',
)
response = self.client.get(self.protected_url)
self.assertEqual(response.status_code, 200)

def test_refresh_token_invalidation_after_new_login(self):
first_login_response = self.client.post(
self.signin_url,
self.user_data,
format='json',
)

refresh_token_v1 = first_login_response.data['refresh']

second_login_response = self.client.post(
Expand Down Expand Up @@ -493,21 +534,3 @@ def test_blacklist_storage(self):
(tb_models.OutstandingToken.objects.count()),
2,
)

def test_token_version_increment(self):
response1 = self.client.post(
self.signin_url,
self.user_data,
format='json',
)
self.assertEqual(response1.data['token_version'], 1)

response2 = self.client.post(
self.signin_url,
self.user_data,
format='json',
)
self.assertEqual(response2.data['token_version'], 2)

user_ = user.models.User.objects.get(email=self.user_data['email'])
self.assertEqual(user_.token_version, 2)
16 changes: 12 additions & 4 deletions promo_code/user/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,13 @@ def create(self, request, *args, **kwargs):
return self.handle_validation_error()

user = serializer.save()

refresh = rest_framework_simplejwt.tokens.RefreshToken.for_user(user)
refresh['token_version'] = user.token_version
access_token = refresh.access_token

return rest_framework.response.Response(
{'token': str(refresh.access_token)},
{'access': str(access_token), 'refresh': str(refresh)},
status=rest_framework.status.HTTP_200_OK,
)

Expand All @@ -49,11 +53,10 @@ class SignInView(
serializer_class = user.serializers.SignInSerializer

def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)

try:
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
response = super().post(request, *args, **kwargs)
except (
rest_framework.serializers.ValidationError,
rest_framework_simplejwt.exceptions.TokenError,
Expand All @@ -63,7 +66,12 @@ def post(self, request, *args, **kwargs):

raise rest_framework_simplejwt.exceptions.InvalidToken(str(e))

response_data = {
'access': serializer.validated_data['access'],
'refresh': serializer.validated_data['refresh'],
}

return rest_framework.response.Response(
response,
response_data,
status=rest_framework.status.HTTP_200_OK,
)