diff --git a/openwisp_controller/connection/api/views.py b/openwisp_controller/connection/api/views.py index 5eb8bf8e9..577885d04 100644 --- a/openwisp_controller/connection/api/views.py +++ b/openwisp_controller/connection/api/views.py @@ -1,11 +1,8 @@ -from django.core.exceptions import ValidationError from django.utils.translation import gettext_lazy as _ from drf_yasg import openapi from drf_yasg.utils import swagger_auto_schema from rest_framework import pagination -from rest_framework.exceptions import NotFound from rest_framework.generics import ( - GenericAPIView, ListCreateAPIView, RetrieveAPIView, RetrieveUpdateDestroyAPIView, @@ -13,9 +10,6 @@ ) from swapper import load_model -from openwisp_users.api.mixins import FilterByParentManaged -from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin - from ...mixins import ( ProtectedAPIMixin, RelatedDeviceModelPermission, @@ -39,10 +33,9 @@ class ListViewPagination(pagination.PageNumberPagination): max_page_size = 100 -class BaseCommandView( - BaseProtectedAPIMixin, - FilterByParentManaged, -): +class BaseCommandView(RelatedDeviceProtectedAPIMixin): + organization_field = "device__organization" + organization_lookup = "organization__in" model = Command queryset = Command.objects.prefetch_related("device") serializer_class = CommandSerializer @@ -116,29 +109,28 @@ class CredentialDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): serializer_class = CredentialSerializer -class BaseDeviceConnection(RelatedDeviceProtectedAPIMixin, GenericAPIView): +class BaseDeviceConnection( + RelatedDeviceProtectedAPIMixin, +): + organization_field = "device__organization" + organization_lookup = "organization__in" model = DeviceConnection serializer_class = DeviceConnectionSerializer + queryset = DeviceConnection.objects.prefetch_related("device") def get_queryset(self): - return DeviceConnection.objects.prefetch_related("device") + return ( + super() + .get_queryset() + .filter(device_id=self.kwargs["device_id"]) + .order_by("-created") + ) def get_serializer_context(self): context = super().get_serializer_context() context["device_id"] = self.kwargs["device_id"] return context - def initial(self, *args, **kwargs): - super().initial(*args, **kwargs) - self.assert_parent_exists() - - def assert_parent_exists(self): - try: - assert self.get_parent_queryset().exists() - except (AssertionError, ValidationError): - device_id = self.kwargs["device_id"] - raise NotFound(detail=f'Device with ID "{device_id}" not found.') - def get_parent_queryset(self): return Device.objects.filter(pk=self.kwargs["device_id"]) @@ -146,14 +138,6 @@ def get_parent_queryset(self): class DeviceConnenctionListCreateView(BaseDeviceConnection, ListCreateAPIView): pagination_class = ListViewPagination - def get_queryset(self): - return ( - super() - .get_queryset() - .filter(device_id=self.kwargs["device_id"]) - .order_by("-created") - ) - class DeviceConnectionDetailView(BaseDeviceConnection, RetrieveUpdateDestroyAPIView): def get_object(self): diff --git a/openwisp_controller/connection/tests/test_api.py b/openwisp_controller/connection/tests/test_api.py index c4448ec87..714ee95ed 100644 --- a/openwisp_controller/connection/tests/test_api.py +++ b/openwisp_controller/connection/tests/test_api.py @@ -315,35 +315,72 @@ def test_endpoints_for_deactivated_device(self): ) self.assertEqual(response.status_code, 200) - def test_non_superuser(self): - list_url = self._get_path("device_command_list", self.device_id) - command = self._create_command(device_conn=self.device_conn) - device = command.device + def _test_command_endpoints( + self, + list_path, + detail_path, + expected_status, + ): + with self.subTest("List operation"): + response = self.client.get(list_path) + self.assertEqual(response.status_code, expected_status["list"]) + + with self.subTest("Create operation"): + response = self.client.post( + list_path, + data={"type": "custom", "input": {"command": "echo test"}}, + content_type="application/json", + ) + self.assertEqual(response.status_code, expected_status["create"]) - with self.subTest("Test with unauthenticated user"): - self.client.logout() - response = self.client.get(list_url) - self.assertEqual(response.status_code, 401) + with self.subTest("Retrieve operation"): + response = self.client.get(detail_path) + self.assertEqual(response.status_code, expected_status["retrieve"]) - with self.subTest("Test with organization member"): - org_user = self._create_org_user(is_admin=True) - org_user.user.groups.add(Group.objects.get(name="Operator")) - self.client.force_login(org_user.user) - self.assertEqual(device.organization, org_user.organization) + def test_endpoints_for_org_operators_own_org(self): + self.client.logout() + operator = self._create_operator(organizations=[self._get_org()]) + self.client.force_login(operator) + list_path = self._get_path("device_command_list", self.device_id) + command = self._create_command(device_conn=self.device_conn) + detail_path = self._get_path( + "device_command_details", self.device_id, command.id + ) + self._test_command_endpoints( + list_path, + detail_path, + expected_status={"list": 200, "create": 201, "retrieve": 200}, + ) - response = self.client.get(list_url) - self.assertEqual(response.status_code, 200) - self.assertEqual(response.data["count"], 1) + def test_endpoints_for_org_operator_different_org(self): + org2 = self._create_org(name="org2", slug="org2") + org2_admin = self._create_operator(organizations=[org2]) + org1_command = self._create_command(device_conn=self.device_conn) + list_path = self._get_path("device_command_list", self.device_id) + detail_path = self._get_path( + "device_command_details", self.device_id, org1_command.id + ) - with self.subTest("Test with org member of different org"): - org2 = self._create_org(name="org2", slug="org2") - org2_user = self._create_user(username="org2user", email="user@org2.com") - self._create_org_user(organization=org2, user=org2_user, is_admin=True) - self.client.force_login(org2_user) - org2_user.groups.add(Group.objects.get(name="Operator")) + self.client.logout() + self.client.force_login(org2_admin) + self._test_command_endpoints( + list_path, + detail_path, + expected_status={"list": 404, "create": 404, "retrieve": 404}, + ) - response = self.client.get(list_url) - self.assertEqual(response.status_code, 404) + def test_unauthenticated_user(self): + list_path = self._get_path("device_command_list", self.device_id) + command = self._create_command(device_conn=self.device_conn) + self.client.logout() + detail_path = self._get_path( + "device_command_details", self.device_id, command.id + ) + self._test_command_endpoints( + list_path, + detail_path, + expected_status={"list": 401, "create": 401, "retrieve": 401}, + ) def test_non_existent_command(self): url = self._get_path("device_command_list", self.device_id) @@ -497,7 +534,7 @@ def test_delete_credential_detail(self): def test_get_deviceconnection_list(self): d1 = self._create_device() path = reverse("connection_api:deviceconnection_list", args=(d1.pk,)) - with self.assertNumQueries(3): + with self.assertNumQueries(4): response = self.client.get(path) self.assertEqual(response.status_code, 200) self.assertEqual(response.data["count"], 0) @@ -554,7 +591,7 @@ def test_get_deviceconnection_detail(self): dc = self._create_device_connection() d1 = dc.device.id path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk)) - with self.assertNumQueries(4): + with self.assertNumQueries(5): response = self.client.get(path) self.assertEqual(response.status_code, 200) @@ -569,7 +606,7 @@ def test_put_devceconnection_detail(self): "enabled": False, "failure_reason": "", } - with self.assertNumQueries(13): + with self.assertNumQueries(14): response = self.client.put(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual( @@ -583,7 +620,7 @@ def test_patch_deviceconnectoin_detail(self): path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk)) self.assertEqual(dc.update_strategy, app_settings.UPDATE_STRATEGIES[0][0]) data = {"update_strategy": app_settings.UPDATE_STRATEGIES[1][0]} - with self.assertNumQueries(12): + with self.assertNumQueries(13): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual( @@ -594,7 +631,7 @@ def test_delete_deviceconnection_detail(self): dc = self._create_device_connection() d1 = dc.device.id path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk)) - with self.assertNumQueries(9): + with self.assertNumQueries(10): response = self.client.delete(path) self.assertEqual(response.status_code, 204) @@ -697,3 +734,129 @@ def test_deactivated_device(self): detail_api_path, ) self.assertEqual(response.status_code, 403) + + def _test_deviceconnection_endpoints( + self, + device_id, + list_path, + detail_path, + expected_status, + ): + with self.subTest("List operation"): + response = self.client.get(list_path) + self.assertEqual(response.status_code, expected_status["list"]) + + with self.subTest("Create operation"): + response = self.client.post( + list_path, + data={ + "credentials": self._get_credentials(name="New Credentials").pk, + "update_strategy": app_settings.UPDATE_STRATEGIES[0][0], + "enabled": True, + "failure_reason": "", + }, + content_type="application/json", + ) + + self.assertEqual(response.status_code, expected_status["create"]) + + with self.subTest("Retrieve operation"): + response = self.client.get(detail_path) + self.assertEqual(response.status_code, expected_status["retrieve"]) + + with self.subTest("Update operation"): + response = self.client.put( + detail_path, + { + "credentials": self._get_credentials().pk, + "update_strategy": app_settings.UPDATE_STRATEGIES[1][0], + "enabled": False, + "failure_reason": "", + }, + content_type="application/json", + ) + self.assertEqual(response.status_code, expected_status["update"]) + + with self.subTest("Partial update operation"): + response = self.client.patch( + detail_path, {"enabled": False}, content_type="application/json" + ) + self.assertEqual(response.status_code, expected_status["patch"]) + + with self.subTest("Delete operation"): + response = self.client.delete(detail_path) + self.assertEqual(response.status_code, expected_status["delete"]) + + def test_deviceconnection_endpoints_for_org_operators_own_org(self): + self.client.logout() + operator = self._create_operator(organizations=[self._get_org()]) + self.client.force_login(operator) + device = self._create_device() + self._create_config(device=device) + dc = self._create_device_connection(device=device) + list_path = reverse("connection_api:deviceconnection_list", args=(device.pk,)) + detail_path = reverse( + "connection_api:deviceconnection_detail", args=(device.pk, dc.pk) + ) + self._test_deviceconnection_endpoints( + device.pk, + list_path, + detail_path, + expected_status={ + "list": 200, + "create": 201, + "retrieve": 200, + "update": 200, + "patch": 200, + "delete": 204, + }, + ) + + def test_deviceconnection_endpoints_for_org_operator_different_org(self): + org2 = self._create_org(name="org2", slug="org2") + org2_operator = self._create_operator(organizations=[org2]) + device = self._create_device() + self._create_config(device=device) + dc = self._create_device_connection(device=device) + list_path = reverse("connection_api:deviceconnection_list", args=(device.pk,)) + detail_path = reverse( + "connection_api:deviceconnection_detail", args=(device.pk, dc.pk) + ) + self.client.logout() + self.client.force_login(org2_operator) + self._test_deviceconnection_endpoints( + device.pk, + list_path, + detail_path, + expected_status={ + "list": 404, + "create": 404, + "retrieve": 404, + "update": 404, + "patch": 404, + "delete": 404, + }, + ) + + def test_deviceconnection_unauthenticated_user(self): + device = self._create_device() + self._create_config(device=device) + dc = self._create_device_connection(device=device) + list_path = reverse("connection_api:deviceconnection_list", args=(device.pk,)) + detail_path = reverse( + "connection_api:deviceconnection_detail", args=(device.pk, dc.pk) + ) + self.client.logout() + self._test_deviceconnection_endpoints( + device.pk, + list_path, + detail_path, + expected_status={ + "list": 401, + "create": 401, + "retrieve": 401, + "update": 401, + "patch": 401, + "delete": 401, + }, + ) diff --git a/openwisp_controller/mixins.py b/openwisp_controller/mixins.py index db1378d92..2e148772a 100644 --- a/openwisp_controller/mixins.py +++ b/openwisp_controller/mixins.py @@ -1,4 +1,4 @@ -from openwisp_users.api.mixins import FilterByOrganizationManaged +from openwisp_users.api.mixins import FilterByOrganizationManaged, FilterByParentManaged from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin from openwisp_users.api.permissions import DjangoModelPermissions, IsOrganizationManager @@ -24,9 +24,7 @@ def has_object_permission(self, request, view, obj): return self._has_permissions(request, view, perm, obj) -class RelatedDeviceProtectedAPIMixin( - BaseProtectedAPIMixin, FilterByOrganizationManaged -): +class RelatedDeviceProtectedAPIMixin(FilterByParentManaged, BaseProtectedAPIMixin): permission_classes = [ IsOrganizationManager, RelatedDeviceModelPermission,