diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index 8fb4569cb1..10bbef87fc 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -8,6 +8,23 @@ SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS') +class PermissionCacheMixin: + def __init__(self): + self._cache = {} + + def has_permission_value(self, request, view): + key = (request, view) + if key not in self._cache: + self._cache[key] = self.has_permission(request, view) + return self._cache[key] + + def has_object_permission_value(self, request, view, obj): + key = (request, view, obj) + if key not in self._cache: + self._cache[key] = self.has_object_permission(request, view, obj) + return self._cache[key] + + class OperationHolderMixin: def __and__(self, other): return OperandHolder(AND, self, other) @@ -55,61 +72,64 @@ def __eq__(self, other): ) -class AND: +class AND(PermissionCacheMixin): def __init__(self, op1, op2): + super().__init__() self.op1 = op1 self.op2 = op2 def has_permission(self, request, view): return ( - self.op1.has_permission(request, view) and - self.op2.has_permission(request, view) + self.op1.has_permission_value(request, view) and + self.op2.has_permission_value(request, view) ) def has_object_permission(self, request, view, obj): return ( - self.op1.has_object_permission(request, view, obj) and - self.op2.has_object_permission(request, view, obj) + self.op1.has_object_permission_value(request, view, obj) and + self.op2.has_object_permission_value(request, view, obj) ) -class OR: +class OR(PermissionCacheMixin): def __init__(self, op1, op2): + super().__init__() self.op1 = op1 self.op2 = op2 def has_permission(self, request, view): return ( - self.op1.has_permission(request, view) or - self.op2.has_permission(request, view) + self.op1.has_permission_value(request, view) or + self.op2.has_permission_value(request, view) ) def has_object_permission(self, request, view, obj): return ( - self.op1.has_permission(request, view) - and self.op1.has_object_permission(request, view, obj) + self.op1.has_permission_value(request, view) + and self.op1.has_object_permission_value(request, view, obj) ) or ( - self.op2.has_permission(request, view) - and self.op2.has_object_permission(request, view, obj) + self.op2.has_permission_value(request, view) + and self.op2.has_object_permission_value(request, view, obj) ) -class NOT: +class NOT(PermissionCacheMixin): def __init__(self, op1): + super().__init__() self.op1 = op1 def has_permission(self, request, view): - return not self.op1.has_permission(request, view) + return not self.op1.has_permission_value(request, view) def has_object_permission(self, request, view, obj): - return not self.op1.has_object_permission(request, view, obj) + return not self.op1.has_object_permission_value(request, view, obj) class BasePermissionMetaclass(OperationHolderMixin, type): pass -class BasePermission(metaclass=BasePermissionMetaclass): +class BasePermission(PermissionCacheMixin, metaclass=BasePermissionMetaclass): """ A base class from which all permission classes should inherit. """ diff --git a/rest_framework/views.py b/rest_framework/views.py index 4c30029fdc..ecce544196 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -1,6 +1,7 @@ """ Provides an APIView class that is the base of all views in REST framework. """ + from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import connections, models @@ -8,6 +9,7 @@ from django.http.response import HttpResponseBase from django.utils.cache import cc_delim_re, patch_vary_headers from django.utils.encoding import smart_str +from django.utils.functional import cached_property from django.views.decorators.csrf import csrf_exempt from django.views.generic import View @@ -277,6 +279,10 @@ def get_permissions(self): """ return [permission() for permission in self.permission_classes] + @cached_property + def cached_permissions(self): + return self.get_permissions() + def get_throttles(self): """ Instantiates and returns the list of throttles that this view uses. @@ -328,8 +334,8 @@ def check_permissions(self, request): Check if the request should be permitted. Raises an appropriate exception if the request is not permitted. """ - for permission in self.get_permissions(): - if not permission.has_permission(request, self): + for permission in self.cached_permissions: + if not permission.has_permission_value(request, self): self.permission_denied( request, message=getattr(permission, 'message', None), @@ -341,8 +347,8 @@ def check_object_permissions(self, request, obj): Check if the request should be permitted for a given object. Raises an appropriate exception if the request is not permitted. """ - for permission in self.get_permissions(): - if not permission.has_object_permission(request, self, obj): + for permission in self.cached_permissions: + if not permission.has_object_permission_value(request, self, obj): self.permission_denied( request, message=getattr(permission, 'message', None), diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 428480dc7e..81a6d6c15d 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -735,3 +735,32 @@ def has_object_permission(self, request, view, obj): composed_perm = (IsAuthenticatedUserOwner | permissions.IsAdminUser) hasperm = composed_perm().has_object_permission(request, None, None) assert hasperm is False + + +class PermissionsCacheTests(TestCase): + + class IsAuthenticatedUserOwnerWithCounter(permissions.IsAuthenticated): + + def __init__(self): + super().__init__() + self.call_counter = 0 + + def has_permission(self, request, view): + self.call_counter += 1 + return True + + def test_composed_perm_permissions(self): + request = factory.get('/1', format='json') + request.user = AnonymousUser() + + composed_perm = (self.IsAuthenticatedUserOwnerWithCounter | permissions.IsAdminUser) + composed_perm_instance = composed_perm() + # in OR composed permissions has_object_permission will call has_permission too. + # we must ensure that this method (has_permission) is called once + has_permission_value = composed_perm_instance.has_permission(request, None) + has_object_permission_value = composed_perm_instance.has_object_permission(request, None, None) + + self.assertTrue(has_permission_value) + self.assertTrue(has_object_permission_value) + + self.assertEqual(composed_perm_instance.op1.call_counter, 1)