Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 17 additions & 24 deletions openwisp_users/api/views.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
from allauth.account.models import EmailAddress
from django.contrib.auth import get_user_model
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
from drf_yasg.utils import swagger_auto_schema
from rest_framework import pagination
from rest_framework.authtoken.views import ObtainAuthToken
from rest_framework.exceptions import NotFound
from rest_framework.generics import (
GenericAPIView,
ListCreateAPIView,
Expand All @@ -20,6 +18,7 @@

from openwisp_users.api.permissions import DjangoModelPermissions

from .mixins import FilterByParent
from .mixins import ProtectedAPIMixin as BaseProtectedAPIMixin
from .serializers import (
ChangePasswordSerializer,
Expand Down Expand Up @@ -198,7 +197,7 @@ def update(self, request, *args, **kwargs):
)


class BaseEmailView(ProtectedAPIMixin, GenericAPIView):
class BaseEmailView(ProtectedAPIMixin, FilterByParent, GenericAPIView):
model = EmailAddress
serializer_class = EmailAddressSerializer

Expand All @@ -209,28 +208,22 @@ def initial(self, *args, **kwargs):
super().initial(*args, **kwargs)
self.assert_parent_exists()

def assert_parent_exists(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method is shipped by FilterByParent so we can remove it.

try:
assert self.get_parent_queryset().exists()
except (AssertionError, ValidationError):
user_id = self.kwargs['pk']
raise NotFound(detail=_("User with ID '{}' not found.".format(user_id)))

def get_parent_queryset(self):
user = self.request.user

if user.is_superuser:
return User.objects.filter(pk=self.kwargs['pk'])

org_users = OrganizationUser.objects.filter(user=user).select_related(
'organization'
)
qs_user = User.objects.none()
for org_user in org_users:
if org_user.is_admin:
qs_user = qs_user | org_user.organization.users.all().distinct()
qs_user = qs_user.filter(is_superuser=False)
return qs_user.filter(pk=self.kwargs['pk'])
qs = User.objects.filter(pk=self.kwargs['pk'])
if self.request.user.is_superuser:
return qs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if the user performing the request is superuser, just return the parent without further checks (superusers can do anything).

return self.get_organization_queryset(qs)

def get_organization_queryset(self, qs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal of this method, if I am not mistaken, is to ensure that the parent object is related to one of the organizations the user performing the API request manages, otherwise the API shall return 404 because nothing is found (the query doens't return any result).

orgs = self.request.user.organizations_managed
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Therefore we use this handy method to get the list of organization IDs the user manages.

app_label = User._meta.app_config.label
filter_kwargs = {
# exclude superusers
'is_superuser': False,
# ensure user is member of the org
f'{app_label}_organizationuser__organization_id__in': orgs,
}
return qs.filter(**filter_kwargs).distinct()

def get_serializer_context(self):
if getattr(self, 'swagger_fake_view', False):
Expand Down
2 changes: 1 addition & 1 deletion openwisp_users/tests/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def test_get_email_list_multitenancy_api(self):
self._create_org_user(user=org2_user, organization=org2)
self.client.force_login(org1_user)
path = reverse('users:email_list', args=(org2_user.pk,))
with self.assertNumQueries(5):
with self.assertNumQueries(4):
response = self.client.get(path)
self.assertEqual(response.status_code, 404)

Expand Down
Loading