Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions promo_code/business/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

import business.constants
import business.models
import business.utils.auth
import business.utils.tokens
import business.validators
import core.utils.auth


class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
Expand All @@ -36,12 +36,6 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer):
required=True,
min_length=business.constants.COMPANY_EMAIL_MIN_LENGTH,
max_length=business.constants.COMPANY_EMAIL_MAX_LENGTH,
validators=[
business.validators.UniqueEmailValidator(
'This email address is already registered.',
'email_conflict',
),
],
)

class Meta:
Expand All @@ -50,11 +44,20 @@ class Meta:

@django.db.transaction.atomic
def create(self, validated_data):
company = business.models.Company.objects.create_company(
**validated_data,
)
try:
company = business.models.Company.objects.create_company(
**validated_data,
)
except django.db.IntegrityError:
exc = rest_framework.exceptions.APIException(
detail={
'email': 'This email address is already registered.',
},
)
exc.status_code = 409
raise exc

return business.utils.auth.bump_company_token_version(company)
return core.utils.auth.bump_token_version(company)


class CompanySignInSerializer(rest_framework.serializers.Serializer):
Expand Down
1 change: 0 additions & 1 deletion promo_code/business/tests/auth/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import django.urls
import rest_framework
import rest_framework.status
import rest_framework.test

import business.models
Expand Down
1 change: 0 additions & 1 deletion promo_code/business/tests/auth/test_authentication.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import rest_framework.status
import rest_framework.test

import business.models
import business.tests.auth.base
Expand Down
1 change: 0 additions & 1 deletion promo_code/business/tests/auth/test_registration.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import rest_framework.status
import rest_framework.test

import business.models
import business.tests.auth.base
Expand Down
1 change: 0 additions & 1 deletion promo_code/business/tests/auth/test_tokens.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import rest_framework.status
import rest_framework.test
import rest_framework_simplejwt.tokens

import business.models
Expand Down
1 change: 0 additions & 1 deletion promo_code/business/tests/auth/test_validation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import parameterized
import rest_framework.status
import rest_framework.test

import business.models
import business.tests.auth.base
Expand Down
13 changes: 0 additions & 13 deletions promo_code/business/utils/auth.py

This file was deleted.

22 changes: 0 additions & 22 deletions promo_code/business/validators.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,6 @@
import rest_framework.exceptions

import business.constants
import business.models


class UniqueEmailValidator:
def __init__(self, default_detail=None, default_code=None):
self.status_code = 409
self.default_detail = (
default_detail or 'This email address is already registered.'
)
self.default_code = default_code or 'email_conflict'

def __call__(self, value):
if business.models.Company.objects.filter(email=value).exists():
exc = rest_framework.exceptions.APIException(
detail={
'status': 'error',
'message': self.default_detail,
'code': self.default_code,
},
)
exc.status_code = self.status_code
raise exc


class PromoValidator:
Expand Down
8 changes: 4 additions & 4 deletions promo_code/business/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
import rest_framework_simplejwt.views

import business.models
import business.pagination
import business.permissions
import business.serializers
import business.utils.auth
import business.utils.tokens
import core.pagination
import core.utils.auth
import user.models


Expand Down Expand Up @@ -49,7 +49,7 @@ def post(self, request, *args, **kwargs):
serializer.is_valid(raise_exception=True)

company = serializer.validated_data['company']
company = business.utils.auth.bump_company_token_version(company)
company = core.utils.auth.bump_token_version(company)

return rest_framework.response.Response(
business.utils.tokens.generate_company_tokens(company),
Expand All @@ -75,7 +75,7 @@ class CompanyPromoListCreateView(rest_framework.generics.ListCreateAPIView):
business.permissions.IsCompanyUser,
]
# Pagination is only needed for GET (listing)
pagination_class = business.pagination.CustomLimitOffsetPagination
pagination_class = core.pagination.CustomLimitOffsetPagination

_validated_query_params = {}

Expand Down
File renamed without changes.
28 changes: 28 additions & 0 deletions promo_code/core/utils/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import django.core.cache
import django.db.models


def bump_token_version(
instance: django.db.models.Model,
) -> django.db.models.Model:
"""
Atomically increments token_version for any model instance
(User or Company), invalidates the corresponding cache,
and returns the updated instance.
"""
user_type = instance.__class__.__name__.lower()

old_token_version = instance.token_version

instance.__class__.objects.filter(id=instance.id).update(
token_version=django.db.models.F('token_version') + 1,
)

old_cache_key = (
f'auth_instance_{user_type}_{instance.id}_v{old_token_version}'
)
django.core.cache.cache.delete(old_cache_key)

instance.refresh_from_db()

return instance
11 changes: 11 additions & 0 deletions promo_code/promo_code/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ def load_bool(name, default):

AUTH_USER_MODEL = 'user.User'

AUTH_INSTANCE_CACHE_TIMEOUT = 3600

REST_FRAMEWORK = {
'DEFAULT_AUTHENTICATION_CLASSES': [
'user.authentication.CustomJWTAuthentication',
Expand Down Expand Up @@ -197,6 +199,15 @@ def load_bool(name, default):
]


PASSWORD_HASHERS = [
'django.contrib.auth.hashers.Argon2PasswordHasher',
'django.contrib.auth.hashers.PBKDF2PasswordHasher',
'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher',
'django.contrib.auth.hashers.BCryptSHA256PasswordHasher',
'django.contrib.auth.hashers.ScryptPasswordHasher',
]


LANGUAGE_CODE = 'en-us'

TIME_ZONE = 'UTC'
Expand Down
51 changes: 41 additions & 10 deletions promo_code/user/authentication.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,23 @@
import django.conf
import django.core.cache
import rest_framework_simplejwt.authentication
import rest_framework_simplejwt.exceptions

import business.models
import user.models as user_models
import user.models


class CustomJWTAuthentication(
rest_framework_simplejwt.authentication.JWTAuthentication,
):
def authenticate(self, request):
"""
Authenticates the user or company based on a JWT token,
supporting multiple user types.
Retrieves the appropriate model instance from the token,
checks token versioning and caches the authenticated instance
for performance.
"""
try:
header = self.get_header(request)
if header is None:
Expand All @@ -20,9 +29,8 @@ def authenticate(self, request):

validated_token = self.get_validated_token(raw_token)
user_type = validated_token.get('user_type', 'user')

model_mapping = {
'user': (user_models.User, 'user_id'),
'user': (user.models.User, 'user_id'),
'company': (business.models.Company, 'company_id'),
}

Expand All @@ -32,21 +40,44 @@ def authenticate(self, request):
)

model_class, id_field = model_mapping[user_type]
instance = model_class.objects.get(
id=validated_token.get(id_field),
instance_id = validated_token.get(id_field)
token_version = validated_token.get('token_version', 0)

cache_key = (
f'auth_instance_{user_type}_{instance_id}_v{token_version}'
)
if instance.token_version != validated_token.get(
'token_version',
0,
):
cached_instance = django.core.cache.cache.get(cache_key)

if cached_instance:
return (cached_instance, validated_token)

if instance_id is None:
raise rest_framework_simplejwt.exceptions.AuthenticationFailed(
f'Missing {id_field} in token',
)

instance = model_class.objects.get(id=instance_id)

if instance.token_version != token_version:
raise rest_framework_simplejwt.exceptions.AuthenticationFailed(
'Token invalid',
)

cache_timeout = getattr(
django.conf.settings,
'AUTH_INSTANCE_CACHE_TIMEOUT',
3600,
)
django.core.cache.cache.set(
cache_key,
instance,
timeout=cache_timeout,
)

return (instance, validated_token)

except (
user_models.User.DoesNotExist,
user.models.User.DoesNotExist,
business.models.Company.DoesNotExist,
):
raise rest_framework_simplejwt.exceptions.AuthenticationFailed(
Expand Down
24 changes: 0 additions & 24 deletions promo_code/user/pagination.py

This file was deleted.

Loading