Skip to content

Commit 253f0e5

Browse files
committed
refactor: Refactor permissions to allow list
1 parent dbdcb20 commit 253f0e5

File tree

3 files changed

+152
-150
lines changed

3 files changed

+152
-150
lines changed

rest_framework/exceptions.py

Lines changed: 56 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
In addition, Django's built in 403 and 404 exceptions are handled.
55
(`django.http.Http404` and `django.core.exceptions.PermissionDenied`)
66
"""
7+
78
import math
89

910
from django.http import JsonResponse
@@ -21,23 +22,20 @@ def _get_error_details(data, default_code=None):
2122
lazy translation strings or strings into `ErrorDetail`.
2223
"""
2324
if isinstance(data, (list, tuple)):
24-
ret = [
25-
_get_error_details(item, default_code) for item in data
26-
]
25+
ret = [_get_error_details(item, default_code) for item in data]
2726
if isinstance(data, ReturnList):
2827
return ReturnList(ret, serializer=data.serializer)
2928
return ret
3029
elif isinstance(data, dict):
3130
ret = {
32-
key: _get_error_details(value, default_code)
33-
for key, value in data.items()
31+
key: _get_error_details(value, default_code) for key, value in data.items()
3432
}
3533
if isinstance(data, ReturnDict):
3634
return ReturnDict(ret, serializer=data.serializer)
3735
return ret
3836

3937
text = force_str(data)
40-
code = getattr(data, 'code', default_code)
38+
code = getattr(data, "code", default_code)
4139
return ErrorDetail(text, code)
4240

4341

@@ -54,16 +52,14 @@ def _get_full_details(detail):
5452
return [_get_full_details(item) for item in detail]
5553
elif isinstance(detail, dict):
5654
return {key: _get_full_details(value) for key, value in detail.items()}
57-
return {
58-
'message': detail,
59-
'code': detail.code
60-
}
55+
return {"message": detail, "code": detail.code}
6156

6257

6358
class ErrorDetail(str):
6459
"""
6560
A string-like object that can additionally have a code.
6661
"""
62+
6763
code = None
6864

6965
def __new__(cls, string, code=None):
@@ -87,7 +83,7 @@ def __ne__(self, other):
8783
return not result
8884

8985
def __repr__(self):
90-
return 'ErrorDetail(string=%r, code=%r)' % (
86+
return "ErrorDetail(string=%r, code=%r)" % (
9187
str(self),
9288
self.code,
9389
)
@@ -101,11 +97,23 @@ class APIException(Exception):
10197
Base class for REST framework exceptions.
10298
Subclasses should provide `.status_code` and `.default_detail` properties.
10399
"""
100+
104101
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
105-
default_detail = _('A server error occurred.')
106-
default_code = 'error'
102+
default_detail = _("A server error occurred.")
103+
default_code = "error"
107104

108105
def __init__(self, detail=None, code=None):
106+
if (
107+
isinstance(detail, tuple)
108+
and isinstance(code, tuple)
109+
and len(detail) == len(code)
110+
):
111+
self.detail = [
112+
_get_error_details(d or self.default_detail, c or self.default_code)
113+
for d, c in zip(detail, code)
114+
]
115+
return
116+
109117
if detail is None:
110118
detail = self.default_detail
111119
if code is None:
@@ -140,10 +148,11 @@ def get_full_details(self):
140148
# from rest_framework import serializers
141149
# raise serializers.ValidationError('Value was invalid')
142150

151+
143152
class ValidationError(APIException):
144153
status_code = status.HTTP_400_BAD_REQUEST
145-
default_detail = _('Invalid input.')
146-
default_code = 'invalid'
154+
default_detail = _("Invalid input.")
155+
default_code = "invalid"
147156

148157
def __init__(self, detail=None, code=None):
149158
if detail is None:
@@ -163,38 +172,38 @@ def __init__(self, detail=None, code=None):
163172

164173
class ParseError(APIException):
165174
status_code = status.HTTP_400_BAD_REQUEST
166-
default_detail = _('Malformed request.')
167-
default_code = 'parse_error'
175+
default_detail = _("Malformed request.")
176+
default_code = "parse_error"
168177

169178

170179
class AuthenticationFailed(APIException):
171180
status_code = status.HTTP_401_UNAUTHORIZED
172-
default_detail = _('Incorrect authentication credentials.')
173-
default_code = 'authentication_failed'
181+
default_detail = _("Incorrect authentication credentials.")
182+
default_code = "authentication_failed"
174183

175184

176185
class NotAuthenticated(APIException):
177186
status_code = status.HTTP_401_UNAUTHORIZED
178-
default_detail = _('Authentication credentials were not provided.')
179-
default_code = 'not_authenticated'
187+
default_detail = _("Authentication credentials were not provided.")
188+
default_code = "not_authenticated"
180189

181190

182191
class PermissionDenied(APIException):
183192
status_code = status.HTTP_403_FORBIDDEN
184-
default_detail = _('You do not have permission to perform this action.')
185-
default_code = 'permission_denied'
193+
default_detail = _("You do not have permission to perform this action.")
194+
default_code = "permission_denied"
186195

187196

188197
class NotFound(APIException):
189198
status_code = status.HTTP_404_NOT_FOUND
190-
default_detail = _('Not found.')
191-
default_code = 'not_found'
199+
default_detail = _("Not found.")
200+
default_code = "not_found"
192201

193202

194203
class MethodNotAllowed(APIException):
195204
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
196205
default_detail = _('Method "{method}" not allowed.')
197-
default_code = 'method_not_allowed'
206+
default_code = "method_not_allowed"
198207

199208
def __init__(self, method, detail=None, code=None):
200209
if detail is None:
@@ -204,8 +213,8 @@ def __init__(self, method, detail=None, code=None):
204213

205214
class NotAcceptable(APIException):
206215
status_code = status.HTTP_406_NOT_ACCEPTABLE
207-
default_detail = _('Could not satisfy the request Accept header.')
208-
default_code = 'not_acceptable'
216+
default_detail = _("Could not satisfy the request Accept header.")
217+
default_code = "not_acceptable"
209218

210219
def __init__(self, detail=None, code=None, available_renderers=None):
211220
self.available_renderers = available_renderers
@@ -215,7 +224,7 @@ def __init__(self, detail=None, code=None, available_renderers=None):
215224
class UnsupportedMediaType(APIException):
216225
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
217226
default_detail = _('Unsupported media type "{media_type}" in request.')
218-
default_code = 'unsupported_media_type'
227+
default_code = "unsupported_media_type"
219228

220229
def __init__(self, media_type, detail=None, code=None):
221230
if detail is None:
@@ -225,21 +234,28 @@ def __init__(self, media_type, detail=None, code=None):
225234

226235
class Throttled(APIException):
227236
status_code = status.HTTP_429_TOO_MANY_REQUESTS
228-
default_detail = _('Request was throttled.')
229-
extra_detail_singular = _('Expected available in {wait} second.')
230-
extra_detail_plural = _('Expected available in {wait} seconds.')
231-
default_code = 'throttled'
237+
default_detail = _("Request was throttled.")
238+
extra_detail_singular = _("Expected available in {wait} second.")
239+
extra_detail_plural = _("Expected available in {wait} seconds.")
240+
default_code = "throttled"
232241

233242
def __init__(self, wait=None, detail=None, code=None):
234243
if detail is None:
235244
detail = force_str(self.default_detail)
236245
if wait is not None:
237246
wait = math.ceil(wait)
238-
detail = ' '.join((
239-
detail,
240-
force_str(ngettext(self.extra_detail_singular.format(wait=wait),
241-
self.extra_detail_plural.format(wait=wait),
242-
wait))))
247+
detail = " ".join(
248+
(
249+
detail,
250+
force_str(
251+
ngettext(
252+
self.extra_detail_singular.format(wait=wait),
253+
self.extra_detail_plural.format(wait=wait),
254+
wait,
255+
)
256+
),
257+
)
258+
)
243259
self.wait = wait
244260
super().__init__(detail, code)
245261

@@ -248,17 +264,13 @@ def server_error(request, *args, **kwargs):
248264
"""
249265
Generic 500 error handler.
250266
"""
251-
data = {
252-
'error': 'Server Error (500)'
253-
}
267+
data = {"error": "Server Error (500)"}
254268
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
255269

256270

257271
def bad_request(request, exception, *args, **kwargs):
258272
"""
259273
Generic 400 error handler.
260274
"""
261-
data = {
262-
'error': 'Bad Request (400)'
263-
}
275+
data = {"error": "Bad Request (400)"}
264276
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)

rest_framework/permissions.py

Lines changed: 50 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Provides a set of pluggable permission policies.
33
"""
44
from django.http import Http404
5-
from django.utils.translation import gettext_lazy as _
65

76
from rest_framework import exceptions
87

@@ -59,61 +58,69 @@ def __hash__(self):
5958
return hash((self.operator_class, self.op1_class, self.op2_class))
6059

6160

62-
class AND:
63-
def __init__(self, op1, op2):
64-
self.op1 = op1
65-
self.op2 = op2
66-
self.message = None
61+
class OperatorBase:
62+
def __init__(self, *permissions):
63+
self._permissions = permissions
6764

68-
def has_permission(self, request, view):
69-
if not self.op1.has_permission(request, view):
70-
self.message = getattr(self.op1, 'message', None)
71-
return False
7265

73-
if not self.op2.has_permission(request, view):
74-
self.message = getattr(self.op2, 'message', None)
75-
return False
66+
class AND(OperatorBase):
7667

68+
def has_permission(self, request, view):
69+
for perm in self._permissions:
70+
if not perm.has_permission(request, view):
71+
self._set_message_and_code(perm)
72+
return False
7773
return True
7874

7975
def has_object_permission(self, request, view, obj):
80-
if not self.op1.has_object_permission(request, view, obj):
81-
self.message = getattr(self.op1, 'message', None)
82-
return False
83-
84-
if not self.op2.has_object_permission(request, view, obj):
85-
self.message = getattr(self.op2, 'message', None)
86-
return False
87-
76+
for perm in self._permissions:
77+
if not perm.has_object_permission(request, view, obj):
78+
self._set_message_and_code(perm)
79+
return False
8880
return True
8981

82+
def _set_message_and_code(self, perm):
83+
self.message = getattr(perm, 'message', None)
84+
self.code = getattr(perm, 'code', None)
9085

91-
class OR:
92-
def __init__(self, op1, op2):
93-
self.op1 = op1
94-
self.op2 = op2
95-
self.message1 = getattr(op1, 'message', None)
96-
self.message2 = getattr(op2, 'message', None)
97-
self.message = self.message1 or self.message2
98-
if self.message1 and self.message2:
99-
self.message = '"{0}" {1} "{2}"'.format(
100-
self.message1, _('OR'), self.message2,
101-
)
86+
87+
class OR(OperatorBase):
10288

10389
def has_permission(self, request, view):
104-
return (
105-
self.op1.has_permission(request, view) or
106-
self.op2.has_permission(request, view)
107-
)
90+
collector = ResultCollector()
91+
for perm in self._permissions:
92+
if perm.has_permission(request, view):
93+
return True
94+
else:
95+
collector.add_message_and_code(perm)
96+
collector.finalize(self)
97+
return False
10898

10999
def has_object_permission(self, request, view, obj):
110-
return (
111-
self.op1.has_permission(request, view)
112-
and self.op1.has_object_permission(request, view, obj)
113-
) or (
114-
self.op2.has_permission(request, view)
115-
and self.op2.has_object_permission(request, view, obj)
116-
)
100+
collector = ResultCollector()
101+
for perm in self._permissions:
102+
if perm.has_permission(request, view) and perm.has_object_permission(request, view, obj):
103+
return True
104+
else:
105+
collector.add_message_and_code(perm)
106+
collector.finalize(self)
107+
return False
108+
109+
110+
class ResultCollector:
111+
def __init__(self):
112+
self.messages = ()
113+
self.codes = ()
114+
115+
def add_message_and_code(self, perm):
116+
message = getattr(perm, 'message', None)
117+
code = getattr(perm, 'code', None)
118+
self.messages += (message,)
119+
self.codes += (code,)
120+
121+
def finalize(self, perm):
122+
perm.message = self.messages
123+
perm.code = self.codes
117124

118125

119126
class NOT:

0 commit comments

Comments
 (0)