diff --git a/promo_code/business/serializers.py b/promo_code/business/serializers.py index 4e6ce4e..3c9b282 100644 --- a/promo_code/business/serializers.py +++ b/promo_code/business/serializers.py @@ -12,9 +12,9 @@ import business.constants import business.models -import business.utils.auth import business.utils.tokens import business.validators +import core.utils.auth class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer): @@ -36,12 +36,6 @@ class CompanySignUpSerializer(rest_framework.serializers.ModelSerializer): required=True, min_length=business.constants.COMPANY_EMAIL_MIN_LENGTH, max_length=business.constants.COMPANY_EMAIL_MAX_LENGTH, - validators=[ - business.validators.UniqueEmailValidator( - 'This email address is already registered.', - 'email_conflict', - ), - ], ) class Meta: @@ -50,11 +44,20 @@ class Meta: @django.db.transaction.atomic def create(self, validated_data): - company = business.models.Company.objects.create_company( - **validated_data, - ) + try: + company = business.models.Company.objects.create_company( + **validated_data, + ) + except django.db.IntegrityError: + exc = rest_framework.exceptions.APIException( + detail={ + 'email': 'This email address is already registered.', + }, + ) + exc.status_code = 409 + raise exc - return business.utils.auth.bump_company_token_version(company) + return core.utils.auth.bump_token_version(company) class CompanySignInSerializer(rest_framework.serializers.Serializer): diff --git a/promo_code/business/tests/auth/base.py b/promo_code/business/tests/auth/base.py index ca100d2..45cfd66 100644 --- a/promo_code/business/tests/auth/base.py +++ b/promo_code/business/tests/auth/base.py @@ -1,6 +1,5 @@ import django.urls import rest_framework -import rest_framework.status import rest_framework.test import business.models diff --git a/promo_code/business/tests/auth/test_authentication.py b/promo_code/business/tests/auth/test_authentication.py index 4a0d357..7a042d1 100644 --- a/promo_code/business/tests/auth/test_authentication.py +++ b/promo_code/business/tests/auth/test_authentication.py @@ -1,5 +1,4 @@ import rest_framework.status -import rest_framework.test import business.models import business.tests.auth.base diff --git a/promo_code/business/tests/auth/test_registration.py b/promo_code/business/tests/auth/test_registration.py index 0920b12..d0ac473 100644 --- a/promo_code/business/tests/auth/test_registration.py +++ b/promo_code/business/tests/auth/test_registration.py @@ -1,5 +1,4 @@ import rest_framework.status -import rest_framework.test import business.models import business.tests.auth.base diff --git a/promo_code/business/tests/auth/test_tokens.py b/promo_code/business/tests/auth/test_tokens.py index 21f726f..d47227c 100644 --- a/promo_code/business/tests/auth/test_tokens.py +++ b/promo_code/business/tests/auth/test_tokens.py @@ -1,5 +1,4 @@ import rest_framework.status -import rest_framework.test import rest_framework_simplejwt.tokens import business.models diff --git a/promo_code/business/tests/auth/test_validation.py b/promo_code/business/tests/auth/test_validation.py index 7b453f4..b96be7e 100644 --- a/promo_code/business/tests/auth/test_validation.py +++ b/promo_code/business/tests/auth/test_validation.py @@ -1,6 +1,5 @@ import parameterized import rest_framework.status -import rest_framework.test import business.models import business.tests.auth.base diff --git a/promo_code/business/utils/auth.py b/promo_code/business/utils/auth.py deleted file mode 100644 index ad1c6f1..0000000 --- a/promo_code/business/utils/auth.py +++ /dev/null @@ -1,13 +0,0 @@ -import business.models - - -def bump_company_token_version(company): - """ - Increment token_version, save it, and return the fresh instance. - """ - company = business.models.Company.objects.select_for_update().get( - id=company.id, - ) - company.token_version += 1 - company.save(update_fields=['token_version']) - return company diff --git a/promo_code/business/validators.py b/promo_code/business/validators.py index 6bcd45b..b66ce07 100644 --- a/promo_code/business/validators.py +++ b/promo_code/business/validators.py @@ -1,28 +1,6 @@ import rest_framework.exceptions import business.constants -import business.models - - -class UniqueEmailValidator: - def __init__(self, default_detail=None, default_code=None): - self.status_code = 409 - self.default_detail = ( - default_detail or 'This email address is already registered.' - ) - self.default_code = default_code or 'email_conflict' - - def __call__(self, value): - if business.models.Company.objects.filter(email=value).exists(): - exc = rest_framework.exceptions.APIException( - detail={ - 'status': 'error', - 'message': self.default_detail, - 'code': self.default_code, - }, - ) - exc.status_code = self.status_code - raise exc class PromoValidator: diff --git a/promo_code/business/views.py b/promo_code/business/views.py index 2472394..9aa7c02 100644 --- a/promo_code/business/views.py +++ b/promo_code/business/views.py @@ -10,11 +10,11 @@ import rest_framework_simplejwt.views import business.models -import business.pagination import business.permissions import business.serializers -import business.utils.auth import business.utils.tokens +import core.pagination +import core.utils.auth import user.models @@ -49,7 +49,7 @@ def post(self, request, *args, **kwargs): serializer.is_valid(raise_exception=True) company = serializer.validated_data['company'] - company = business.utils.auth.bump_company_token_version(company) + company = core.utils.auth.bump_token_version(company) return rest_framework.response.Response( business.utils.tokens.generate_company_tokens(company), @@ -75,7 +75,7 @@ class CompanyPromoListCreateView(rest_framework.generics.ListCreateAPIView): business.permissions.IsCompanyUser, ] # Pagination is only needed for GET (listing) - pagination_class = business.pagination.CustomLimitOffsetPagination + pagination_class = core.pagination.CustomLimitOffsetPagination _validated_query_params = {} diff --git a/promo_code/business/pagination.py b/promo_code/core/pagination.py similarity index 100% rename from promo_code/business/pagination.py rename to promo_code/core/pagination.py diff --git a/promo_code/core/utils/auth.py b/promo_code/core/utils/auth.py new file mode 100644 index 0000000..6f99c10 --- /dev/null +++ b/promo_code/core/utils/auth.py @@ -0,0 +1,28 @@ +import django.core.cache +import django.db.models + + +def bump_token_version( + instance: django.db.models.Model, +) -> django.db.models.Model: + """ + Atomically increments token_version for any model instance + (User or Company), invalidates the corresponding cache, + and returns the updated instance. + """ + user_type = instance.__class__.__name__.lower() + + old_token_version = instance.token_version + + instance.__class__.objects.filter(id=instance.id).update( + token_version=django.db.models.F('token_version') + 1, + ) + + old_cache_key = ( + f'auth_instance_{user_type}_{instance.id}_v{old_token_version}' + ) + django.core.cache.cache.delete(old_cache_key) + + instance.refresh_from_db() + + return instance diff --git a/promo_code/promo_code/settings.py b/promo_code/promo_code/settings.py index edc06d3..72f329b 100644 --- a/promo_code/promo_code/settings.py +++ b/promo_code/promo_code/settings.py @@ -48,6 +48,8 @@ def load_bool(name, default): AUTH_USER_MODEL = 'user.User' +AUTH_INSTANCE_CACHE_TIMEOUT = 3600 + REST_FRAMEWORK = { 'DEFAULT_AUTHENTICATION_CLASSES': [ 'user.authentication.CustomJWTAuthentication', @@ -197,6 +199,15 @@ def load_bool(name, default): ] +PASSWORD_HASHERS = [ + 'django.contrib.auth.hashers.Argon2PasswordHasher', + 'django.contrib.auth.hashers.PBKDF2PasswordHasher', + 'django.contrib.auth.hashers.PBKDF2SHA1PasswordHasher', + 'django.contrib.auth.hashers.BCryptSHA256PasswordHasher', + 'django.contrib.auth.hashers.ScryptPasswordHasher', +] + + LANGUAGE_CODE = 'en-us' TIME_ZONE = 'UTC' diff --git a/promo_code/user/authentication.py b/promo_code/user/authentication.py index c42fa37..96312af 100644 --- a/promo_code/user/authentication.py +++ b/promo_code/user/authentication.py @@ -1,14 +1,23 @@ +import django.conf +import django.core.cache import rest_framework_simplejwt.authentication import rest_framework_simplejwt.exceptions import business.models -import user.models as user_models +import user.models class CustomJWTAuthentication( rest_framework_simplejwt.authentication.JWTAuthentication, ): def authenticate(self, request): + """ + Authenticates the user or company based on a JWT token, + supporting multiple user types. + Retrieves the appropriate model instance from the token, + checks token versioning and caches the authenticated instance + for performance. + """ try: header = self.get_header(request) if header is None: @@ -20,9 +29,8 @@ def authenticate(self, request): validated_token = self.get_validated_token(raw_token) user_type = validated_token.get('user_type', 'user') - model_mapping = { - 'user': (user_models.User, 'user_id'), + 'user': (user.models.User, 'user_id'), 'company': (business.models.Company, 'company_id'), } @@ -32,21 +40,44 @@ def authenticate(self, request): ) model_class, id_field = model_mapping[user_type] - instance = model_class.objects.get( - id=validated_token.get(id_field), + instance_id = validated_token.get(id_field) + token_version = validated_token.get('token_version', 0) + + cache_key = ( + f'auth_instance_{user_type}_{instance_id}_v{token_version}' ) - if instance.token_version != validated_token.get( - 'token_version', - 0, - ): + cached_instance = django.core.cache.cache.get(cache_key) + + if cached_instance: + return (cached_instance, validated_token) + + if instance_id is None: + raise rest_framework_simplejwt.exceptions.AuthenticationFailed( + f'Missing {id_field} in token', + ) + + instance = model_class.objects.get(id=instance_id) + + if instance.token_version != token_version: raise rest_framework_simplejwt.exceptions.AuthenticationFailed( 'Token invalid', ) + cache_timeout = getattr( + django.conf.settings, + 'AUTH_INSTANCE_CACHE_TIMEOUT', + 3600, + ) + django.core.cache.cache.set( + cache_key, + instance, + timeout=cache_timeout, + ) + return (instance, validated_token) except ( - user_models.User.DoesNotExist, + user.models.User.DoesNotExist, business.models.Company.DoesNotExist, ): raise rest_framework_simplejwt.exceptions.AuthenticationFailed( diff --git a/promo_code/user/pagination.py b/promo_code/user/pagination.py deleted file mode 100644 index d079fbc..0000000 --- a/promo_code/user/pagination.py +++ /dev/null @@ -1,24 +0,0 @@ -import rest_framework.pagination -import rest_framework.response - - -class UserFeedPagination(rest_framework.pagination.LimitOffsetPagination): - default_limit = 10 - max_limit = 100 - - def get_limit(self, request): - raw_limit = request.query_params.get(self.limit_query_param) - - if raw_limit is None: - return self.default_limit - - limit = int(raw_limit) - - # Allow 0, otherwise cut by max_limit - return 0 if limit == 0 else min(limit, self.max_limit) - - def get_paginated_response(self, data): - return rest_framework.response.Response( - data, - headers={'X-Total-Count': str(self.count)}, - ) diff --git a/promo_code/user/serializers.py b/promo_code/user/serializers.py index 72199f3..cbdba3a 100644 --- a/promo_code/user/serializers.py +++ b/promo_code/user/serializers.py @@ -1,7 +1,6 @@ import django.contrib.auth.password_validation -import django.core.exceptions -import django.core.validators -import django.db.models +import django.core.cache +import django.db.transaction import pycountry import rest_framework.exceptions import rest_framework.serializers @@ -11,9 +10,9 @@ import business.constants import business.models +import core.utils.auth import user.constants import user.models -import user.validators class OtherFieldSerializer(rest_framework.serializers.Serializer): @@ -64,19 +63,11 @@ class SignUpSerializer(rest_framework.serializers.ModelSerializer): required=True, min_length=user.constants.EMAIL_MIN_LENGTH, max_length=user.constants.EMAIL_MAX_LENGTH, - validators=[ - user.validators.UniqueEmailValidator( - 'This email address is already registered.', - 'email_conflict', - ), - ], - ) - avatar_url = rest_framework.serializers.CharField( + ) + avatar_url = rest_framework.serializers.URLField( required=False, max_length=user.constants.AVATAR_URL_MAX_LENGTH, - validators=[ - django.core.validators.URLValidator(schemes=['http', 'https']), - ], + allow_null=True, ) other = OtherFieldSerializer(required=True) @@ -91,21 +82,20 @@ class Meta: 'password', ) + @django.db.transaction.atomic def create(self, validated_data): try: - user_ = user.models.User.objects.create_user( - email=validated_data['email'], - name=validated_data['name'], - surname=validated_data['surname'], - avatar_url=validated_data.get('avatar_url'), - other=validated_data['other'], - password=validated_data['password'], + user_ = user.models.User.objects.create_user(**validated_data) + except django.db.IntegrityError: + exc = rest_framework.exceptions.APIException( + detail={ + 'email': 'This email address is already registered.', + }, ) - user_.token_version += 1 - user_.save() - return user_ - except django.core.exceptions.ValidationError as e: - raise rest_framework.serializers.ValidationError(e.messages) + exc.status_code = 409 + raise exc + + return core.utils.auth.bump_token_version(user_) class SignInSerializer( @@ -120,14 +110,18 @@ class SignInSerializer( def validate(self, attrs): user = self.authenticate_user(attrs) - user.token_version = django.db.models.F('token_version') + 1 - user.save(update_fields=['token_version']) + user = core.utils.auth.bump_token_version(user) - data = super().validate(attrs) + self.user = user - refresh = rest_framework_simplejwt.tokens.RefreshToken(data['refresh']) + data = super().validate(attrs) - self.blacklist_other_tokens(user, refresh['jti']) + refresh = data.get('refresh') + if refresh: + refresh_token = rest_framework_simplejwt.tokens.RefreshToken( + refresh, + ) + self.blacklist_other_tokens(user, refresh_token['jti']) return data @@ -187,12 +181,6 @@ class UserProfileSerializer(rest_framework.serializers.ModelSerializer): required=False, min_length=user.constants.EMAIL_MIN_LENGTH, max_length=user.constants.EMAIL_MAX_LENGTH, - validators=[ - user.validators.UniqueEmailValidator( - 'This email address is already registered.', - 'email_conflict', - ), - ], ) password = rest_framework.serializers.CharField( write_only=True, @@ -202,12 +190,10 @@ class UserProfileSerializer(rest_framework.serializers.ModelSerializer): min_length=user.constants.PASSWORD_MIN_LENGTH, style={'input_type': 'password'}, ) - avatar_url = rest_framework.serializers.CharField( + avatar_url = rest_framework.serializers.URLField( required=False, max_length=user.constants.AVATAR_URL_MAX_LENGTH, - validators=[ - django.core.validators.URLValidator(schemes=['http', 'https']), - ], + allow_null=True, ) other = OtherFieldSerializer(required=False) @@ -233,10 +219,28 @@ def update(self, instance, validated_data): if other_data is not None: instance.other = other_data + if ( + 'email' in validated_data + and user.models.User.objects.filter( + email=validated_data['email'], + ) + .exclude(id=instance.id) + .exists() + ): + raise rest_framework.exceptions.ValidationError( + {'email': 'This email address is already registered.'}, + ) + for attr, value in validated_data.items(): setattr(instance, attr, value) instance.save() + + user_type = instance.__class__.__name__.lower() + token_version = instance.token_version + + cache_key = f'auth_instance_{user_type}_{instance.id}_v{token_version}' + django.core.cache.cache.delete(cache_key) return instance def to_representation(self, instance): diff --git a/promo_code/user/tests/user/base.py b/promo_code/user/tests/user/base.py index 9a3cf82..d010219 100644 --- a/promo_code/user/tests/user/base.py +++ b/promo_code/user/tests/user/base.py @@ -1,5 +1,4 @@ import django.conf -import django.core.cache import django.urls import django_redis import rest_framework.test diff --git a/promo_code/user/validators.py b/promo_code/user/validators.py deleted file mode 100644 index cbd8941..0000000 --- a/promo_code/user/validators.py +++ /dev/null @@ -1,24 +0,0 @@ -import rest_framework.exceptions - -import user.models - - -class UniqueEmailValidator: - def __init__(self, default_detail=None, default_code=None): - self.status_code = 409 - self.default_detail = ( - default_detail or 'This email address is already registered.' - ) - self.default_code = default_code or 'email_conflict' - - def __call__(self, value): - if user.models.User.objects.filter(email=value).exists(): - exc = rest_framework.exceptions.APIException( - detail={ - 'status': 'error', - 'message': self.default_detail, - 'code': self.default_code, - }, - ) - exc.status_code = self.status_code - raise exc diff --git a/promo_code/user/views.py b/promo_code/user/views.py index 70bd8d2..34936fb 100644 --- a/promo_code/user/views.py +++ b/promo_code/user/views.py @@ -13,9 +13,9 @@ import business.constants import business.models +import core.pagination import user.antifraud_service import user.models -import user.pagination import user.permissions import user.serializers @@ -109,7 +109,7 @@ class UserPromoDetailView(rest_framework.generics.RetrieveAPIView): class UserFeedView(rest_framework.generics.ListAPIView): serializer_class = user.serializers.PromoFeedSerializer permission_classes = [rest_framework.permissions.IsAuthenticated] - pagination_class = user.pagination.UserFeedPagination + pagination_class = core.pagination.CustomLimitOffsetPagination def get_queryset(self): user = self.request.user @@ -287,7 +287,7 @@ def delete(self, request, id): class PromoCommentListCreateView(rest_framework.generics.ListCreateAPIView): permission_classes = [rest_framework.permissions.IsAuthenticated] - pagination_class = user.pagination.UserFeedPagination + pagination_class = core.pagination.CustomLimitOffsetPagination def get_serializer_class(self): if self.request.method == 'POST': @@ -537,7 +537,7 @@ class PromoHistoryView(rest_framework.generics.ListAPIView): serializer_class = user.serializers.UserPromoDetailSerializer permission_classes = [rest_framework.permissions.IsAuthenticated] - pagination_class = user.pagination.UserFeedPagination + pagination_class = core.pagination.CustomLimitOffsetPagination def get_queryset(self): user = self.request.user diff --git a/requirements/prod.txt b/requirements/prod.txt index 8755a4f..18ecb42 100644 --- a/requirements/prod.txt +++ b/requirements/prod.txt @@ -1,3 +1,4 @@ +argon2-cffi==25.1.0 django==5.2 django-redis==6.0.0 djangorestframework==3.15.2 @@ -7,4 +8,4 @@ psycopg2-binary==2.9.10 pycountry==24.6.1 python-dotenv==1.0.1 requests==2.32.4 -parameterized==0.9.0 \ No newline at end of file +parameterized==0.9.0