Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 15 additions & 31 deletions openwisp_controller/connection/api/views.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
from django.core.exceptions import ValidationError
from django.utils.translation import gettext_lazy as _
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework import pagination
from rest_framework.exceptions import NotFound
from rest_framework.generics import (
GenericAPIView,
ListCreateAPIView,
RetrieveAPIView,
RetrieveUpdateDestroyAPIView,
get_object_or_404,
)
from swapper import load_model

from openwisp_users.api.mixins import FilterByParentManaged
from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin

from ...mixins import (
ProtectedAPIMixin,
RelatedDeviceModelPermission,
Expand All @@ -39,10 +33,9 @@ class ListViewPagination(pagination.PageNumberPagination):
max_page_size = 100


class BaseCommandView(
BaseProtectedAPIMixin,
FilterByParentManaged,
):
class BaseCommandView(RelatedDeviceProtectedAPIMixin):
organization_field = "device__organization"
organization_lookup = "organization__in"
model = Command
queryset = Command.objects.prefetch_related("device")
serializer_class = CommandSerializer
Expand Down Expand Up @@ -116,44 +109,35 @@ class CredentialDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView):
serializer_class = CredentialSerializer


class BaseDeviceConnection(RelatedDeviceProtectedAPIMixin, GenericAPIView):
class BaseDeviceConnection(
RelatedDeviceProtectedAPIMixin,
):
organization_field = "device__organization"
organization_lookup = "organization__in"
model = DeviceConnection
serializer_class = DeviceConnectionSerializer
queryset = DeviceConnection.objects.prefetch_related("device")

def get_queryset(self):
return DeviceConnection.objects.prefetch_related("device")
return (
super()
.get_queryset()
.filter(device_id=self.kwargs["device_id"])
.order_by("-created")
)

def get_serializer_context(self):
context = super().get_serializer_context()
context["device_id"] = self.kwargs["device_id"]
return context

def initial(self, *args, **kwargs):
super().initial(*args, **kwargs)
self.assert_parent_exists()

def assert_parent_exists(self):
try:
assert self.get_parent_queryset().exists()
except (AssertionError, ValidationError):
device_id = self.kwargs["device_id"]
raise NotFound(detail=f'Device with ID "{device_id}" not found.')

def get_parent_queryset(self):
return Device.objects.filter(pk=self.kwargs["device_id"])


class DeviceConnenctionListCreateView(BaseDeviceConnection, ListCreateAPIView):
pagination_class = ListViewPagination

def get_queryset(self):
return (
super()
.get_queryset()
.filter(device_id=self.kwargs["device_id"])
.order_by("-created")
)


class DeviceConnectionDetailView(BaseDeviceConnection, RetrieveUpdateDestroyAPIView):
def get_object(self):
Expand Down
221 changes: 192 additions & 29 deletions openwisp_controller/connection/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,35 +315,72 @@ def test_endpoints_for_deactivated_device(self):
)
self.assertEqual(response.status_code, 200)

def test_non_superuser(self):
list_url = self._get_path("device_command_list", self.device_id)
command = self._create_command(device_conn=self.device_conn)
device = command.device
def _test_command_endpoints(
self,
list_path,
detail_path,
expected_status,
):
with self.subTest("List operation"):
response = self.client.get(list_path)
self.assertEqual(response.status_code, expected_status["list"])

with self.subTest("Create operation"):
response = self.client.post(
list_path,
data={"type": "custom", "input": {"command": "echo test"}},
content_type="application/json",
)
self.assertEqual(response.status_code, expected_status["create"])

with self.subTest("Test with unauthenticated user"):
self.client.logout()
response = self.client.get(list_url)
self.assertEqual(response.status_code, 401)
with self.subTest("Retrieve operation"):
response = self.client.get(detail_path)
self.assertEqual(response.status_code, expected_status["retrieve"])

with self.subTest("Test with organization member"):
org_user = self._create_org_user(is_admin=True)
org_user.user.groups.add(Group.objects.get(name="Operator"))
self.client.force_login(org_user.user)
self.assertEqual(device.organization, org_user.organization)
def test_endpoints_for_org_operators_own_org(self):
self.client.logout()
operator = self._create_operator(organizations=[self._get_org()])
self.client.force_login(operator)
list_path = self._get_path("device_command_list", self.device_id)
command = self._create_command(device_conn=self.device_conn)
detail_path = self._get_path(
"device_command_details", self.device_id, command.id
)
self._test_command_endpoints(
list_path,
detail_path,
expected_status={"list": 200, "create": 201, "retrieve": 200},
)

response = self.client.get(list_url)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["count"], 1)
def test_endpoints_for_org_operator_different_org(self):
org2 = self._create_org(name="org2", slug="org2")
org2_admin = self._create_operator(organizations=[org2])
org1_command = self._create_command(device_conn=self.device_conn)
list_path = self._get_path("device_command_list", self.device_id)
detail_path = self._get_path(
"device_command_details", self.device_id, org1_command.id
)

with self.subTest("Test with org member of different org"):
org2 = self._create_org(name="org2", slug="org2")
org2_user = self._create_user(username="org2user", email="[email protected]")
self._create_org_user(organization=org2, user=org2_user, is_admin=True)
self.client.force_login(org2_user)
org2_user.groups.add(Group.objects.get(name="Operator"))
self.client.logout()
self.client.force_login(org2_admin)
self._test_command_endpoints(
list_path,
detail_path,
expected_status={"list": 404, "create": 404, "retrieve": 404},
)

response = self.client.get(list_url)
self.assertEqual(response.status_code, 404)
def test_unauthenticated_user(self):
list_path = self._get_path("device_command_list", self.device_id)
command = self._create_command(device_conn=self.device_conn)
self.client.logout()
detail_path = self._get_path(
"device_command_details", self.device_id, command.id
)
self._test_command_endpoints(
list_path,
detail_path,
expected_status={"list": 401, "create": 401, "retrieve": 401},
)

def test_non_existent_command(self):
url = self._get_path("device_command_list", self.device_id)
Expand Down Expand Up @@ -497,7 +534,7 @@ def test_delete_credential_detail(self):
def test_get_deviceconnection_list(self):
d1 = self._create_device()
path = reverse("connection_api:deviceconnection_list", args=(d1.pk,))
with self.assertNumQueries(3):
with self.assertNumQueries(4):
response = self.client.get(path)
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data["count"], 0)
Expand Down Expand Up @@ -554,7 +591,7 @@ def test_get_deviceconnection_detail(self):
dc = self._create_device_connection()
d1 = dc.device.id
path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk))
with self.assertNumQueries(4):
with self.assertNumQueries(5):
response = self.client.get(path)
self.assertEqual(response.status_code, 200)

Expand All @@ -569,7 +606,7 @@ def test_put_devceconnection_detail(self):
"enabled": False,
"failure_reason": "",
}
with self.assertNumQueries(13):
with self.assertNumQueries(14):
response = self.client.put(path, data, content_type="application/json")
self.assertEqual(response.status_code, 200)
self.assertEqual(
Expand All @@ -583,7 +620,7 @@ def test_patch_deviceconnectoin_detail(self):
path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk))
self.assertEqual(dc.update_strategy, app_settings.UPDATE_STRATEGIES[0][0])
data = {"update_strategy": app_settings.UPDATE_STRATEGIES[1][0]}
with self.assertNumQueries(12):
with self.assertNumQueries(13):
response = self.client.patch(path, data, content_type="application/json")
self.assertEqual(response.status_code, 200)
self.assertEqual(
Expand All @@ -594,7 +631,7 @@ def test_delete_deviceconnection_detail(self):
dc = self._create_device_connection()
d1 = dc.device.id
path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk))
with self.assertNumQueries(9):
with self.assertNumQueries(10):
response = self.client.delete(path)
self.assertEqual(response.status_code, 204)

Expand Down Expand Up @@ -697,3 +734,129 @@ def test_deactivated_device(self):
detail_api_path,
)
self.assertEqual(response.status_code, 403)

def _test_deviceconnection_endpoints(
self,
device_id,
list_path,
detail_path,
expected_status,
):
with self.subTest("List operation"):
response = self.client.get(list_path)
self.assertEqual(response.status_code, expected_status["list"])

with self.subTest("Create operation"):
response = self.client.post(
list_path,
data={
"credentials": self._get_credentials(name="New Credentials").pk,
"update_strategy": app_settings.UPDATE_STRATEGIES[0][0],
"enabled": True,
"failure_reason": "",
},
content_type="application/json",
)

self.assertEqual(response.status_code, expected_status["create"])

with self.subTest("Retrieve operation"):
response = self.client.get(detail_path)
self.assertEqual(response.status_code, expected_status["retrieve"])

with self.subTest("Update operation"):
response = self.client.put(
detail_path,
{
"credentials": self._get_credentials().pk,
"update_strategy": app_settings.UPDATE_STRATEGIES[1][0],
"enabled": False,
"failure_reason": "",
},
content_type="application/json",
)
self.assertEqual(response.status_code, expected_status["update"])

with self.subTest("Partial update operation"):
response = self.client.patch(
detail_path, {"enabled": False}, content_type="application/json"
)
self.assertEqual(response.status_code, expected_status["patch"])

with self.subTest("Delete operation"):
response = self.client.delete(detail_path)
self.assertEqual(response.status_code, expected_status["delete"])

def test_deviceconnection_endpoints_for_org_operators_own_org(self):
self.client.logout()
operator = self._create_operator(organizations=[self._get_org()])
self.client.force_login(operator)
device = self._create_device()
self._create_config(device=device)
dc = self._create_device_connection(device=device)
list_path = reverse("connection_api:deviceconnection_list", args=(device.pk,))
detail_path = reverse(
"connection_api:deviceconnection_detail", args=(device.pk, dc.pk)
)
self._test_deviceconnection_endpoints(
device.pk,
list_path,
detail_path,
expected_status={
"list": 200,
"create": 201,
"retrieve": 200,
"update": 200,
"patch": 200,
"delete": 204,
},
)

def test_deviceconnection_endpoints_for_org_operator_different_org(self):
org2 = self._create_org(name="org2", slug="org2")
org2_operator = self._create_operator(organizations=[org2])
device = self._create_device()
self._create_config(device=device)
dc = self._create_device_connection(device=device)
list_path = reverse("connection_api:deviceconnection_list", args=(device.pk,))
detail_path = reverse(
"connection_api:deviceconnection_detail", args=(device.pk, dc.pk)
)
self.client.logout()
self.client.force_login(org2_operator)
self._test_deviceconnection_endpoints(
device.pk,
list_path,
detail_path,
expected_status={
"list": 404,
"create": 404,
"retrieve": 404,
"update": 404,
"patch": 404,
"delete": 404,
},
)

def test_deviceconnection_unauthenticated_user(self):
device = self._create_device()
self._create_config(device=device)
dc = self._create_device_connection(device=device)
list_path = reverse("connection_api:deviceconnection_list", args=(device.pk,))
detail_path = reverse(
"connection_api:deviceconnection_detail", args=(device.pk, dc.pk)
)
self.client.logout()
self._test_deviceconnection_endpoints(
device.pk,
list_path,
detail_path,
expected_status={
"list": 401,
"create": 401,
"retrieve": 401,
"update": 401,
"patch": 401,
"delete": 401,
},
)
6 changes: 2 additions & 4 deletions openwisp_controller/mixins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from openwisp_users.api.mixins import FilterByOrganizationManaged
from openwisp_users.api.mixins import FilterByOrganizationManaged, FilterByParentManaged
from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin
from openwisp_users.api.permissions import DjangoModelPermissions, IsOrganizationManager

Expand All @@ -24,9 +24,7 @@ def has_object_permission(self, request, view, obj):
return self._has_permissions(request, view, perm, obj)


class RelatedDeviceProtectedAPIMixin(
BaseProtectedAPIMixin, FilterByOrganizationManaged
):
class RelatedDeviceProtectedAPIMixin(FilterByParentManaged, BaseProtectedAPIMixin):
permission_classes = [
IsOrganizationManager,
RelatedDeviceModelPermission,
Expand Down
Loading