Skip to content

Commit f471f58

Browse files
committed
Add custom validators, permission classes, and PermissionTestModel for testing permissions
1 parent d3dd45b commit f471f58

File tree

5 files changed

+229
-2
lines changed

5 files changed

+229
-2
lines changed

rest_framework/fields.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,46 @@ def to_internal_value(self, data):
882882
return super().to_internal_value(data)
883883

884884

885+
class AlphabeticFieldValidator:
886+
"""
887+
Custom validator to ensure that a field only contains alphabetic characters and spaces.
888+
"""
889+
def __call__(self, value):
890+
if not isinstance(value, str):
891+
raise ValueError("This field must be a string.")
892+
if value == "":
893+
raise ValueError("This field must contain only alphabetic characters and spaces.")
894+
if not re.match(r'^[A-Za-z ]*$', value):
895+
raise ValueError("This field must contain only alphabetic characters and spaces.")
896+
897+
class AlphanumericFieldValidator:
898+
"""
899+
Custom validator to ensure the field contains only alphanumeric characters (letters and numbers).
900+
"""
901+
def __call__(self, value):
902+
if not isinstance(value, str):
903+
raise ValueError("This field must be a string.")
904+
if value == "":
905+
raise ValueError("This field must contain only alphanumeric characters (letters and numbers).")
906+
if not re.match(r'^[A-Za-z0-9]*$', value):
907+
raise ValueError("This field must contain only alphanumeric characters (letters and numbers).")
908+
909+
class CustomLengthValidator:
910+
"""
911+
Custom validator to ensure the length of a string is within specified limits.
912+
"""
913+
def __init__(self, min_length=0, max_length=None):
914+
self.min_length = min_length
915+
self.max_length = max_length
916+
917+
def __call__(self, value):
918+
if len(value) < self.min_length:
919+
raise ValueError(f"This field must be at least {self.min_length} characters long.")
920+
921+
if self.max_length is not None and len(value) > self.max_length:
922+
raise ValueError(f"This field must be no more than {self.max_length} characters long.")
923+
924+
885925
# Number types...
886926

887927
class IntegerField(Field):

rest_framework/permissions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,31 @@ def has_permission(self, request, view):
172172
request.user.is_authenticated
173173
)
174174

175+
class IsAdminUserOrReadOnly(BasePermission):
176+
"""
177+
Custom permission to only allow admin users to edit an object.
178+
"""
179+
180+
def has_permission(self, request, view):
181+
# Allow any user to view the object
182+
if request.method in ['GET', 'HEAD', 'OPTIONS']:
183+
return True
184+
# Only allow admin users to modify the object
185+
return request.user and request.user.is_staff
186+
187+
188+
class IsOwner(BasePermission):
189+
"""
190+
Custom permission to only allow owners of an object to edit it.
191+
"""
192+
193+
def has_object_permission(self, request, view, obj):
194+
# Allow read-only access to any request
195+
if request.method in ['GET', 'HEAD', 'OPTIONS']:
196+
return True
197+
# Write permissions are only allowed to the owner of the object
198+
return obj.owner == request.user
199+
175200

176201
class DjangoModelPermissions(BasePermission):
177202
"""

tests/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,3 +150,11 @@ def __new__(cls, *args, **kwargs):
150150
help_text='OneToOneTarget',
151151
verbose_name='OneToOneTarget',
152152
on_delete=models.CASCADE)
153+
154+
155+
class OwnershipTestModel(models.Model):
156+
owner = models.ForeignKey(User, on_delete=models.CASCADE, related_name='ownership_test_models')
157+
title = models.CharField(max_length=100)
158+
159+
def __str__(self):
160+
return self.title

tests/test_fields.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import rest_framework
2626
from rest_framework import exceptions, serializers
2727
from rest_framework.fields import (
28-
BuiltinSignatureError, DjangoImageField, SkipField, empty,
28+
AlphabeticFieldValidator, AlphanumericFieldValidator, BuiltinSignatureError, CustomLengthValidator, DjangoImageField, SkipField, empty,
2929
is_simple_callable
3030
)
3131
from tests.models import UUIDForeignKeyTarget
@@ -1061,6 +1061,112 @@ class TestFilePathField(FieldValues):
10611061
)
10621062

10631063

1064+
class TestAlphabeticField:
1065+
valid_inputs = {
1066+
'John Doe': 'John Doe',
1067+
'Alice': 'Alice',
1068+
'Bob Marley': 'Bob Marley',
1069+
}
1070+
invalid_inputs = {
1071+
'John123': ['This field must contain only alphabetic characters and spaces.'],
1072+
'Alice!': ['This field must contain only alphabetic characters and spaces.'],
1073+
'': ['This field must contain only alphabetic characters and spaces.'],
1074+
}
1075+
non_string_inputs = [
1076+
123, # Integer
1077+
45.67, # Float
1078+
None, # NoneType
1079+
[], # Empty list
1080+
{}, # Empty dict
1081+
set() # Empty set
1082+
]
1083+
1084+
def test_valid_inputs(self):
1085+
validator = AlphabeticFieldValidator()
1086+
for value in self.valid_inputs.keys():
1087+
validator(value)
1088+
1089+
def test_invalid_inputs(self):
1090+
validator = AlphabeticFieldValidator()
1091+
for value, expected_errors in self.invalid_inputs.items():
1092+
with pytest.raises(ValueError) as excinfo:
1093+
validator(value)
1094+
assert str(excinfo.value) == expected_errors[0]
1095+
1096+
def test_non_string_inputs(self):
1097+
validator = AlphabeticFieldValidator()
1098+
for value in self.non_string_inputs:
1099+
with pytest.raises(ValueError) as excinfo:
1100+
validator(value)
1101+
assert str(excinfo.value) == "This field must be a string."
1102+
1103+
1104+
class TestAlphanumericField:
1105+
valid_inputs = {
1106+
'John123': 'John123',
1107+
'Alice007': 'Alice007',
1108+
'Bob1990': 'Bob1990',
1109+
}
1110+
invalid_inputs = {
1111+
'John!': ['This field must contain only alphanumeric characters (letters and numbers).'],
1112+
'Alice 007': ['This field must contain only alphanumeric characters (letters and numbers).'],
1113+
'': ['This field must contain only alphanumeric characters (letters and numbers).'],
1114+
}
1115+
non_string_inputs = [
1116+
123, # Integer
1117+
45.67, # Float
1118+
None, # NoneType
1119+
[], # Empty list
1120+
{}, # Empty dict
1121+
set() # Empty set
1122+
]
1123+
1124+
def test_valid_inputs(self):
1125+
validator = AlphanumericFieldValidator()
1126+
for value in self.valid_inputs.keys():
1127+
validator(value)
1128+
1129+
def test_invalid_inputs(self):
1130+
validator = AlphanumericFieldValidator()
1131+
for value, expected_errors in self.invalid_inputs.items():
1132+
with pytest.raises(ValueError) as excinfo:
1133+
validator(value)
1134+
assert str(excinfo.value) == expected_errors[0]
1135+
1136+
def test_non_string_inputs(self):
1137+
validator = AlphanumericFieldValidator()
1138+
for value in self.non_string_inputs:
1139+
with pytest.raises(ValueError) as excinfo:
1140+
validator(value)
1141+
assert str(excinfo.value) == "This field must be a string."
1142+
1143+
class TestCustomLengthField:
1144+
"""
1145+
Valid and invalid values for `CustomLengthValidator`.
1146+
"""
1147+
valid_inputs = {
1148+
'abc': 'abc', # 3 characters
1149+
'abcdefghij': 'abcdefghij', # 10 characters
1150+
}
1151+
invalid_inputs = {
1152+
'ab': ['This field must be at least 3 characters long.'], # Too short
1153+
'abcdefghijk': ['This field must be no more than 10 characters long.'], # Too long
1154+
}
1155+
field = str
1156+
1157+
def test_valid_inputs(self):
1158+
validator = CustomLengthValidator(min_length=3, max_length=10)
1159+
for value in self.valid_inputs.keys():
1160+
validator(value)
1161+
1162+
def test_invalid_inputs(self):
1163+
validator = CustomLengthValidator(min_length=3, max_length=10)
1164+
for value, expected_errors in self.invalid_inputs.items():
1165+
with pytest.raises(ValueError) as excinfo:
1166+
validator(value)
1167+
assert str(excinfo.value) == expected_errors[0]
1168+
1169+
10641170
# Number types...
10651171

10661172
class TestIntegerField(FieldValues):

tests/test_permissions.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
)
1515
from rest_framework.routers import DefaultRouter
1616
from rest_framework.test import APIRequestFactory
17-
from tests.models import BasicModel
17+
from tests.models import BasicModel, OwnershipTestModel
1818

1919
factory = APIRequestFactory()
2020

@@ -772,3 +772,51 @@ def test_filtering_permissions(self):
772772
]
773773

774774
assert filtered_permissions == expected_permissions
775+
776+
777+
class PermissionTests(TestCase):
778+
def setUp(self):
779+
self.factory = APIRequestFactory()
780+
self.admin_user = User.objects.create_user(username='admin', password='password', is_staff=True)
781+
self.regular_user = User.objects.create_user(username='user', password='password')
782+
self.anonymous_user = AnonymousUser()
783+
784+
def test_is_admin_user_or_read_only_allow_read(self):
785+
request = self.factory.get('/1', format='json')
786+
request.user = self.anonymous_user
787+
permission = permissions.IsAdminUserOrReadOnly()
788+
self.assertTrue(permission.has_permission(request, None))
789+
790+
request.user = self.admin_user
791+
self.assertTrue(permission.has_permission(request, None))
792+
793+
def test_is_admin_user_or_read_only_allow_write(self):
794+
request = self.factory.post('/1', format='json')
795+
request.user = self.admin_user
796+
permission = permissions.IsAdminUserOrReadOnly()
797+
self.assertTrue(permission.has_permission(request, None))
798+
799+
request.user = self.regular_user
800+
self.assertFalse(permission.has_permission(request, None))
801+
802+
def test_is_owner_permission(self):
803+
obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title')
804+
805+
request = self.factory.post('/1', format='json')
806+
request.user = self.admin_user
807+
permission = permissions.IsOwner()
808+
self.assertTrue(permission.has_object_permission(request, None, obj))
809+
810+
request.user = self.regular_user
811+
self.assertFalse(permission.has_object_permission(request, None, obj))
812+
813+
def test_is_owner_read_access(self):
814+
obj = OwnershipTestModel.objects.create(owner=self.admin_user, title='Test Title')
815+
816+
request = self.factory.get('/1', format='json')
817+
request.user = self.regular_user
818+
permission = permissions.IsOwner()
819+
self.assertTrue(permission.has_object_permission(request, None, obj))
820+
821+
request.user = self.admin_user
822+
self.assertTrue(permission.has_object_permission(request, None, obj))

0 commit comments

Comments
 (0)