diff --git a/openwisp_controller/config/api/serializers.py b/openwisp_controller/config/api/serializers.py index 7e67d8f5f..de0d252bd 100644 --- a/openwisp_controller/config/api/serializers.py +++ b/openwisp_controller/config/api/serializers.py @@ -3,6 +3,7 @@ from django.db.models import Q from django.utils.translation import gettext_lazy as _ from rest_framework import serializers +from reversion.models import Version from swapper import load_model from openwisp_utils.api.serializers import ValidatedModelSerializer @@ -376,3 +377,28 @@ def update(self, instance, validated_data): instance = super().update(instance, validated_data) self._save_m2m_templates(instance) return instance + + +class VersionSerializer(BaseSerializer): + user_id = serializers.CharField(source="revision.user_id", read_only=True) + date_created = serializers.DateTimeField( + source="revision.date_created", read_only=True + ) + comment = serializers.CharField(source="revision.comment", read_only=True) + content_type = serializers.CharField(source="revision.content_type", read_only=True) + + class Meta: + model = Version + fields = [ + "id", + "revision_id", + "object_id", + "content_type", + "db", + "format", + "serialized_data", + "object_repr", + "user_id", + "date_created", + "comment", + ] diff --git a/openwisp_controller/config/api/urls.py b/openwisp_controller/config/api/urls.py index 9936ad213..3f545bd76 100644 --- a/openwisp_controller/config/api/urls.py +++ b/openwisp_controller/config/api/urls.py @@ -13,6 +13,21 @@ def get_api_urls(api_views): """ if getattr(settings, "OPENWISP_CONTROLLER_API", True): return [ + path( + "controller//revision/", + api_views.revision_list, + name="revision_list", + ), + path( + "controller//revision//", + api_views.version_detail, + name="version_detail", + ), + path( + "controller//revision//restore/", + api_views.revision_restore, + name="revision_restore", + ), path( "controller/template/", api_views.template_list, diff --git a/openwisp_controller/config/api/views.py b/openwisp_controller/config/api/views.py index db77ce15c..fa6930948 100644 --- a/openwisp_controller/config/api/views.py +++ b/openwisp_controller/config/api/views.py @@ -1,22 +1,27 @@ +import reversion from cache_memoize import cache_memoize from django.core.exceptions import ObjectDoesNotExist +from django.db import transaction from django.db.models import F, Q from django.http import Http404 +from django.shortcuts import get_list_or_404 from django.urls.base import reverse from django_filters.rest_framework import DjangoFilterBackend from rest_framework import pagination, serializers, status from rest_framework.generics import ( GenericAPIView, + ListAPIView, ListCreateAPIView, RetrieveAPIView, RetrieveUpdateDestroyAPIView, ) from rest_framework.response import Response +from reversion.models import Version from swapper import load_model from openwisp_users.api.permissions import DjangoModelPermissions -from ...mixins import ProtectedAPIMixin +from ...mixins import AutoRevisionMixin, ProtectedAPIMixin from .filters import ( DeviceGroupListFilter, DeviceListFilter, @@ -29,6 +34,7 @@ DeviceGroupSerializer, DeviceListSerializer, TemplateSerializer, + VersionSerializer, VpnSerializer, ) @@ -48,7 +54,7 @@ class ListViewPagination(pagination.PageNumberPagination): max_page_size = 100 -class TemplateListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class TemplateListCreateView(ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView): serializer_class = TemplateSerializer queryset = Template.objects.prefetch_related("tags").order_by("-created") pagination_class = ListViewPagination @@ -56,12 +62,14 @@ class TemplateListCreateView(ProtectedAPIMixin, ListCreateAPIView): filterset_class = TemplateListFilter -class TemplateDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class TemplateDetailView( + ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView +): serializer_class = TemplateSerializer queryset = Template.objects.all() -class VpnListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class VpnListCreateView(ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView): serializer_class = VpnSerializer queryset = Vpn.objects.select_related("subnet").order_by("-created") pagination_class = ListViewPagination @@ -69,7 +77,7 @@ class VpnListCreateView(ProtectedAPIMixin, ListCreateAPIView): filterset_class = VPNListFilter -class VpnDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class VpnDetailView(ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView): serializer_class = VpnSerializer queryset = Vpn.objects.all() @@ -82,7 +90,7 @@ def has_object_permission(self, request, view, obj): return perm and not obj.is_deactivated() -class DeviceListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class DeviceListCreateView(ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView): """ Templates: Templates flagged as required will be added automatically to the `config` of a device and cannot be unassigned. @@ -97,7 +105,9 @@ class DeviceListCreateView(ProtectedAPIMixin, ListCreateAPIView): filterset_class = DeviceListFilter -class DeviceDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class DeviceDetailView( + ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView +): """ Templates: Templates flagged as _required_ will be added automatically to the `config` of a device and cannot be unassigned. @@ -124,7 +134,7 @@ def get_serializer_context(self): return context -class DeviceActivateView(ProtectedAPIMixin, GenericAPIView): +class DeviceActivateView(ProtectedAPIMixin, AutoRevisionMixin, GenericAPIView): serializer_class = serializers.Serializer queryset = Device.objects.filter(_is_deactivated=True) @@ -137,7 +147,7 @@ def post(self, request, *args, **kwargs): return Response(serializer.data, status=status.HTTP_200_OK) -class DeviceDeactivateView(ProtectedAPIMixin, GenericAPIView): +class DeviceDeactivateView(ProtectedAPIMixin, AutoRevisionMixin, GenericAPIView): serializer_class = serializers.Serializer queryset = Device.objects.filter(_is_deactivated=False) @@ -150,7 +160,9 @@ def post(self, request, *args, **kwargs): return Response(serializer.data, status=status.HTTP_200_OK) -class DeviceGroupListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class DeviceGroupListCreateView( + ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView +): serializer_class = DeviceGroupSerializer queryset = DeviceGroup.objects.prefetch_related("templates").order_by("-created") pagination_class = ListViewPagination @@ -158,7 +170,9 @@ class DeviceGroupListCreateView(ProtectedAPIMixin, ListCreateAPIView): filterset_class = DeviceGroupListFilter -class DeviceGroupDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class DeviceGroupDetailView( + ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView +): serializer_class = DeviceGroupSerializer queryset = DeviceGroup.objects.select_related("organization").order_by("-created") @@ -172,7 +186,7 @@ def get_cached_devicegroup_args_rewrite(cls, org_slugs, common_name): return url -class DeviceGroupCommonName(ProtectedAPIMixin, RetrieveAPIView): +class DeviceGroupCommonName(ProtectedAPIMixin, AutoRevisionMixin, RetrieveAPIView): serializer_class = DeviceGroupSerializer queryset = DeviceGroup.objects.select_related("organization").order_by("-created") # Not setting lookup_field makes DRF raise error. but it is not used @@ -289,6 +303,60 @@ def certificate_delete_invalidates_cache(cls, organization_id, common_name): cls.get_device_group.invalidate(cls, org_slug, common_name) +class RevisionListView(ProtectedAPIMixin, ListAPIView): + serializer_class = VersionSerializer + queryset = Version.objects.select_related("revision").order_by( + "-revision__date_created" + ) + + def get_queryset(self): + model = self.kwargs.get("model").lower() + queryset = self.queryset.filter(content_type__model=model) + revision_id = self.request.query_params.get("revision_id") + if revision_id: + queryset = queryset.filter(revision_id=revision_id) + return self.queryset.filter(content_type__model=model) + + +class VersionDetailView(ProtectedAPIMixin, RetrieveAPIView): + serializer_class = VersionSerializer + queryset = Version.objects.select_related("revision").order_by( + "-revision__date_created" + ) + + def get_queryset(self): + model = self.kwargs.get("model").lower() + return self.queryset.filter(content_type__model=model) + + +class RevisionRestoreView(ProtectedAPIMixin, GenericAPIView): + serializer_class = serializers.Serializer + queryset = Version.objects.select_related("revision").order_by( + "-revision__date_created" + ) + + def get_queryset(self): + model = self.kwargs.get("model").lower() + return self.queryset.filter(content_type__model=model) + + def post(self, request, *args, **kwargs): + qs = self.get_queryset() + versions = get_list_or_404(qs, revision_id=kwargs["pk"]) + with transaction.atomic(): + with reversion.create_revision(): + for version in versions: + version.revert() + reversion.set_user(request.user) + reversion.set_comment( + f"Restored to previous revision: {self.kwargs.get('pk')}" + ) + + serializer = VersionSerializer( + versions, many=True, context=self.get_serializer_context() + ) + return Response(serializer.data, status=status.HTTP_200_OK) + + template_list = TemplateListCreateView.as_view() template_detail = TemplateDetailView.as_view() vpn_list = VpnListCreateView.as_view() @@ -300,3 +368,6 @@ def certificate_delete_invalidates_cache(cls, organization_id, common_name): devicegroup_list = DeviceGroupListCreateView.as_view() devicegroup_detail = DeviceGroupDetailView.as_view() devicegroup_commonname = DeviceGroupCommonName.as_view() +revision_list = RevisionListView.as_view() +version_detail = VersionDetailView.as_view() +revision_restore = RevisionRestoreView.as_view() diff --git a/openwisp_controller/config/tests/test_api.py b/openwisp_controller/config/tests/test_api.py index 880c142d1..8b1614d95 100644 --- a/openwisp_controller/config/tests/test_api.py +++ b/openwisp_controller/config/tests/test_api.py @@ -1,3 +1,4 @@ +import reversion from django.contrib.auth.models import Permission from django.test import TestCase from django.test.client import BOUNDARY, MULTIPART_CONTENT, encode_multipart @@ -1634,3 +1635,49 @@ def test_device_patch_with_templates_of_same_org(self): self.assertEqual(r.status_code, 200) self.assertEqual(d1.config.templates.count(), 2) self.assertEqual(r.data["config"]["templates"], [t1.id, t2.id]) + + def test_revision_list_and_restore_api(self): + org = self._get_org() + model_slug = "device" + with reversion.create_revision(): + device = self._create_device( + organization=org, + name="test", + ) + path = reverse("config_api:device_detail", args=[device.pk]) + data = dict(name="change-test-device") + response = self.client.patch(path, data, content_type="application/json") + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data["name"], "change-test-device") + + with self.subTest("Test revision list"): + path = reverse("config_api:revision_list", args=[model_slug]) + response = self.client.get(path) + response_json = response.json() + version_id = response_json[1]["id"] + revision_id = response_json[1]["revision_id"] + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response_json), 2) + + with self.subTest("Test revision list filter by revision id"): + path = reverse("config_api:revision_list", args=[model_slug]) + response = self.client.get(f"{path}?revision_id={revision_id}") + response_json = response.json() + self.assertEqual(response.status_code, 200) + self.assertEqual(len(response_json), 2) + + with self.subTest("Test version detail"): + path = reverse("config_api:version_detail", args=[model_slug, version_id]) + response = self.client.get(path) + self.assertEqual(response.status_code, 200) + self.assertEqual(response.json()["id"], version_id) + self.assertEqual(response.json()["object_id"], str(device.pk)) + + with self.subTest("Test revision restore view"): + revision_id = response_json[1]["revision_id"] + path = reverse( + "config_api:revision_restore", args=[model_slug, revision_id] + ) + response = self.client.post(path) + self.assertEqual(response.status_code, 200) + self.assertEqual(Device.objects.get(name="test").pk, device.pk) diff --git a/openwisp_controller/connection/api/views.py b/openwisp_controller/connection/api/views.py index 5ebc19947..cda3d544c 100644 --- a/openwisp_controller/connection/api/views.py +++ b/openwisp_controller/connection/api/views.py @@ -14,6 +14,7 @@ from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin from ...mixins import ( + AutoRevisionMixin, ProtectedAPIMixin, RelatedDeviceModelPermission, RelatedDeviceProtectedAPIMixin, @@ -61,7 +62,7 @@ def get_serializer_context(self): return context -class CommandListCreateView(BaseCommandView, ListCreateAPIView): +class CommandListCreateView(BaseCommandView, AutoRevisionMixin, ListCreateAPIView): pagination_class = ListViewPagination def create(self, request, *args, **kwargs): @@ -81,13 +82,15 @@ def get_object(self): return obj -class CredentialListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class CredentialListCreateView(ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView): queryset = Credentials.objects.order_by("-created") serializer_class = CredentialSerializer pagination_class = ListViewPagination -class CredentialDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class CredentialDetailView( + ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView +): queryset = Credentials.objects.all() serializer_class = CredentialSerializer @@ -119,7 +122,9 @@ def get_parent_queryset(self): return Device.objects.filter(pk=self.kwargs["device_id"]) -class DeviceConnenctionListCreateView(BaseDeviceConnection, ListCreateAPIView): +class DeviceConnenctionListCreateView( + BaseDeviceConnection, AutoRevisionMixin, ListCreateAPIView +): pagination_class = ListViewPagination def get_queryset(self): @@ -131,7 +136,9 @@ def get_queryset(self): ) -class DeviceConnectionDetailView(BaseDeviceConnection, RetrieveUpdateDestroyAPIView): +class DeviceConnectionDetailView( + BaseDeviceConnection, AutoRevisionMixin, RetrieveUpdateDestroyAPIView +): def get_object(self): queryset = self.filter_queryset(self.get_queryset()) filter_kwargs = { diff --git a/openwisp_controller/connection/tests/test_api.py b/openwisp_controller/connection/tests/test_api.py index be13a7474..1a55402c5 100644 --- a/openwisp_controller/connection/tests/test_api.py +++ b/openwisp_controller/connection/tests/test_api.py @@ -494,7 +494,7 @@ def test_post_deviceconnection_list(self): "enabled": True, "failure_reason": "", } - with self.assertNumQueries(13): + with self.assertNumQueries(23): response = self.client.post(path, data, content_type="application/json") self.assertEqual(response.status_code, 201) @@ -539,7 +539,7 @@ def test_put_devceconnection_detail(self): "enabled": False, "failure_reason": "", } - with self.assertNumQueries(14): + with self.assertNumQueries(23): response = self.client.put(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual( @@ -553,7 +553,7 @@ def test_patch_deviceconnectoin_detail(self): path = reverse("connection_api:deviceconnection_detail", args=(d1, dc.pk)) self.assertEqual(dc.update_strategy, app_settings.UPDATE_STRATEGIES[0][0]) data = {"update_strategy": app_settings.UPDATE_STRATEGIES[1][0]} - with self.assertNumQueries(13): + with self.assertNumQueries(22): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual( diff --git a/openwisp_controller/geo/api/views.py b/openwisp_controller/geo/api/views.py index b5514cf26..b25c5f641 100644 --- a/openwisp_controller/geo/api/views.py +++ b/openwisp_controller/geo/api/views.py @@ -14,7 +14,11 @@ from openwisp_users.api.filters import OrganizationManagedFilter from openwisp_users.api.mixins import FilterByOrganizationManaged, FilterByParentManaged -from ...mixins import ProtectedAPIMixin, RelatedDeviceProtectedAPIMixin +from ...mixins import ( + AutoRevisionMixin, + ProtectedAPIMixin, + RelatedDeviceProtectedAPIMixin, +) from .filters import DeviceListFilter from .serializers import ( DeviceCoordinatesSerializer, @@ -57,7 +61,9 @@ class ListViewPagination(pagination.PageNumberPagination): max_page_size = 100 -class DeviceCoordinatesView(ProtectedAPIMixin, generics.RetrieveUpdateAPIView): +class DeviceCoordinatesView( + ProtectedAPIMixin, AutoRevisionMixin, generics.RetrieveUpdateAPIView +): serializer_class = DeviceCoordinatesSerializer permission_classes = (DevicePermission,) queryset = Device.objects.select_related( @@ -105,6 +111,7 @@ def create_location(self, device): class DeviceLocationView( RelatedDeviceProtectedAPIMixin, + AutoRevisionMixin, FilterByParentManaged, generics.RetrieveUpdateDestroyAPIView, ): @@ -203,7 +210,9 @@ def get_queryset(self): return qs -class FloorPlanListCreateView(ProtectedAPIMixin, generics.ListCreateAPIView): +class FloorPlanListCreateView( + ProtectedAPIMixin, AutoRevisionMixin, generics.ListCreateAPIView +): serializer_class = FloorPlanSerializer queryset = FloorPlan.objects.select_related().order_by("-created") pagination_class = ListViewPagination @@ -213,13 +222,16 @@ class FloorPlanListCreateView(ProtectedAPIMixin, generics.ListCreateAPIView): class FloorPlanDetailView( ProtectedAPIMixin, + AutoRevisionMixin, generics.RetrieveUpdateDestroyAPIView, ): serializer_class = FloorPlanSerializer queryset = FloorPlan.objects.select_related() -class LocationListCreateView(ProtectedAPIMixin, generics.ListCreateAPIView): +class LocationListCreateView( + ProtectedAPIMixin, AutoRevisionMixin, generics.ListCreateAPIView +): serializer_class = LocationSerializer queryset = Location.objects.order_by("-created") pagination_class = ListViewPagination @@ -229,6 +241,7 @@ class LocationListCreateView(ProtectedAPIMixin, generics.ListCreateAPIView): class LocationDetailView( ProtectedAPIMixin, + AutoRevisionMixin, generics.RetrieveUpdateDestroyAPIView, ): serializer_class = LocationSerializer diff --git a/openwisp_controller/geo/tests/test_api.py b/openwisp_controller/geo/tests/test_api.py index b9cea2c98..1398d8720 100644 --- a/openwisp_controller/geo/tests/test_api.py +++ b/openwisp_controller/geo/tests/test_api.py @@ -3,6 +3,7 @@ import uuid from django.contrib.auth import get_user_model +from django.contrib.contenttypes.models import ContentType from django.contrib.gis.geos import Point from django.test import TestCase from django.test.client import BOUNDARY, MULTIPART_CONTENT, encode_multipart @@ -300,6 +301,7 @@ class TestGeoApi( def setUp(self): admin = self._create_admin() self.client.force_login(admin) + ContentType.objects.clear_cache() def _create_device_location(self, **kwargs): options = dict() @@ -494,7 +496,7 @@ def test_post_location_list(self): "address": "Via del Corso, Roma, Italia", "geometry": coords, } - with self.assertNumQueries(9): + with self.assertNumQueries(13): response = self.client.post(path, data, content_type="application/json") self.assertEqual(response.status_code, 201) @@ -525,7 +527,7 @@ def test_put_location_detail(self): "address": "Via del Corso, Roma, Italia", "geometry": coords, } - with self.assertNumQueries(6): + with self.assertNumQueries(10): response = self.client.put(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual(response.data["organization"], org1.pk) @@ -536,7 +538,7 @@ def test_patch_location_detail(self): self.assertEqual(l1.name, "test-location") path = reverse("geo_api:detail_location", args=[l1.pk]) data = {"name": "change-test-location"} - with self.assertNumQueries(5): + with self.assertNumQueries(9): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual(response.data["name"], "change-test-location") @@ -566,7 +568,7 @@ def test_patch_floorplan_detail_api(self): fl = self._create_floorplan(location=l1) path = reverse("geo_api:detail_location", args=[l1.pk]) data = {"floorplan": {"floor": 13}} - with self.assertNumQueries(13): + with self.assertNumQueries(17): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) fl.refresh_from_db() @@ -577,7 +579,7 @@ def test_change_location_type_to_outdoor_api(self): self._create_floorplan(location=l1) path = reverse("geo_api:detail_location", args=[l1.pk]) data = {"type": "outdoor"} - with self.assertNumQueries(9): + with self.assertNumQueries(13): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) self.assertEqual(response.data["floorplan"], []) @@ -603,7 +605,7 @@ def test_create_location_with_floorplan(self): "floorplan.floor": ["23"], "floorplan.image": [fl_image], } - with self.assertNumQueries(16): + with self.assertNumQueries(20): response = self.client.post(path, data, format="multipart") self.assertEqual(response.status_code, 201) self.assertEqual(Location.objects.count(), 1) @@ -627,7 +629,7 @@ def test_create_new_floorplan_with_put_location_api(self): "floorplan.floor": "23", "floorplan.image": fl_image, } - with self.assertNumQueries(16): + with self.assertNumQueries(20): response = self.client.put( path, encode_multipart(BOUNDARY, data), content_type=MULTIPART_CONTENT ) @@ -722,7 +724,7 @@ def test_create_devicelocation_using_related_ids(self): floorplan = self._create_floorplan() location = floorplan.location url = reverse("geo_api:device_location", args=[device.id]) - with self.assertNumQueries(18): + with self.assertNumQueries(29): response = self.client.put( url, data={ @@ -760,7 +762,7 @@ def test_create_devicelocation_location_floorplan(self): "floorplan.image": self._get_simpleuploadedfile(), "indoor": ["12.342,23.541"], } - with self.assertNumQueries(32): + with self.assertNumQueries(43): response = self.client.put( url, encode_multipart(BOUNDARY, data), content_type=MULTIPART_CONTENT ) @@ -827,7 +829,7 @@ def test_create_devicelocation_only_location(self): "type": "indoor", } } - with self.assertNumQueries(21): + with self.assertNumQueries(32): response = self.client.put(url, data=data, content_type="application/json") self.assertEqual(response.status_code, 201) self.assertEqual(self.location_model.objects.count(), 1) @@ -867,7 +869,7 @@ def test_create_devicelocation_existing_location_new_floorplan(self): "floorplan.image": self._get_simpleuploadedfile(), "indoor": ["12.342,23.541"], } - with self.assertNumQueries(26): + with self.assertNumQueries(37): response = self.client.put( url, encode_multipart(BOUNDARY, data), content_type=MULTIPART_CONTENT ) @@ -890,7 +892,7 @@ def test_update_devicelocation_change_location_outdoor_to_indoor(self): } self.assertEqual(device_location.location.type, "outdoor") self.assertEqual(device_location.floorplan, None) - with self.assertNumQueries(23): + with self.assertNumQueries(33): response = self.client.put( path, encode_multipart(BOUNDARY, data), content_type=MULTIPART_CONTENT ) @@ -909,7 +911,7 @@ def test_update_devicelocation_patch_indoor(self): "indoor": "0,0", } self.assertEqual(device_location.indoor, "-140.38620,40.369227") - with self.assertNumQueries(12): + with self.assertNumQueries(20): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) device_location.refresh_from_db() @@ -926,7 +928,7 @@ def test_update_devicelocation_floorplan_related_id(self): data = { "floorplan": str(floor2.id), } - with self.assertNumQueries(14): + with self.assertNumQueries(22): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) device_location.refresh_from_db() @@ -940,7 +942,7 @@ def test_update_devicelocation_location_related_id(self): data = { "location": str(location2.id), } - with self.assertNumQueries(11): + with self.assertNumQueries(21): response = self.client.patch(path, data, content_type="application/json") self.assertEqual(response.status_code, 200) device_location.refresh_from_db() diff --git a/openwisp_controller/mixins.py b/openwisp_controller/mixins.py index db1378d92..d84cac9c6 100644 --- a/openwisp_controller/mixins.py +++ b/openwisp_controller/mixins.py @@ -1,3 +1,6 @@ +import reversion +from reversion.views import RevisionMixin + from openwisp_users.api.mixins import FilterByOrganizationManaged from openwisp_users.api.mixins import ProtectedAPIMixin as BaseProtectedAPIMixin from openwisp_users.api.permissions import DjangoModelPermissions, IsOrganizationManager @@ -35,3 +38,25 @@ class RelatedDeviceProtectedAPIMixin( class ProtectedAPIMixin(BaseProtectedAPIMixin, FilterByOrganizationManaged): pass + + +class AutoRevisionMixin(RevisionMixin): + revision_atomic = False + + def dispatch(self, request, *args, **kwargs): + qs = getattr(self, "queryset", None) + model = getattr(qs, "model", None) + if ( + request.method in ("POST", "PUT", "PATCH") + and request.user.is_authenticated + and model + and reversion.is_registered(model) + ): + with reversion.create_revision(atomic=self.revision_atomic): + response = super().dispatch(request, *args, **kwargs) + reversion.set_user(request.user) + reversion.set_comment( + f"API request: {request.method} {request.get_full_path()}" + ) + return response + return super().dispatch(request, *args, **kwargs) diff --git a/openwisp_controller/pki/api/views.py b/openwisp_controller/pki/api/views.py index 9d5ffdf03..e82399531 100644 --- a/openwisp_controller/pki/api/views.py +++ b/openwisp_controller/pki/api/views.py @@ -10,7 +10,7 @@ from rest_framework.response import Response from swapper import load_model -from ...mixins import ProtectedAPIMixin +from ...mixins import AutoRevisionMixin, ProtectedAPIMixin from .serializers import ( CaDetailSerializer, CaListSerializer, @@ -30,18 +30,18 @@ class ListViewPagination(pagination.PageNumberPagination): max_page_size = 100 -class CaListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class CaListCreateView(ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView): serializer_class = CaListSerializer queryset = Ca.objects.order_by("-created") pagination_class = ListViewPagination -class CaDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class CaDetailView(ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView): serializer_class = CaDetailSerializer queryset = Ca.objects.all() -class CaRenewView(ProtectedAPIMixin, GenericAPIView): +class CaRenewView(ProtectedAPIMixin, AutoRevisionMixin, GenericAPIView): serializer_class = serializers.Serializer queryset = Ca.objects.all() @@ -66,18 +66,20 @@ def retrieve(self, request, *args, **kwargs): ) -class CertListCreateView(ProtectedAPIMixin, ListCreateAPIView): +class CertListCreateView(ProtectedAPIMixin, AutoRevisionMixin, ListCreateAPIView): serializer_class = CertListSerializer queryset = Cert.objects.order_by("-created") pagination_class = ListViewPagination -class CertDetailView(ProtectedAPIMixin, RetrieveUpdateDestroyAPIView): +class CertDetailView( + ProtectedAPIMixin, AutoRevisionMixin, RetrieveUpdateDestroyAPIView +): serializer_class = CertDetailSerializer queryset = Cert.objects.select_related("ca") -class CertRevokeRenewBaseView(ProtectedAPIMixin, GenericAPIView): +class CertRevokeRenewBaseView(ProtectedAPIMixin, AutoRevisionMixin, GenericAPIView): serializer_class = serializers.Serializer queryset = Cert.objects.select_related("ca") diff --git a/openwisp_controller/pki/tests/test_api.py b/openwisp_controller/pki/tests/test_api.py index d2d36f1d4..05b2300b2 100644 --- a/openwisp_controller/pki/tests/test_api.py +++ b/openwisp_controller/pki/tests/test_api.py @@ -51,7 +51,7 @@ def test_ca_post_api(self): self.assertEqual(Ca.objects.count(), 0) path = reverse("pki_api:ca_list") data = self._ca_data - with self.assertNumQueries(4): + with self.assertNumQueries(7): r = self.client.post(path, data, content_type="application/json") self.assertEqual(r.status_code, 201) self.assertEqual(Ca.objects.count(), 1) @@ -61,7 +61,7 @@ def test_ca_post_with_extensions_field(self): path = reverse("pki_api:ca_list") data = self._ca_data data["extensions"] = [] - with self.assertNumQueries(4): + with self.assertNumQueries(7): r = self.client.post(path, data, content_type="application/json") self.assertEqual(r.status_code, 201) self.assertEqual(r.data["extensions"], []) @@ -77,7 +77,7 @@ def test_ca_import_post_api(self): "private_key": ca1.private_key, } expected_queries = ( - 7 if parse_version(REST_FRAMEWORK_VERSION) >= parse_version("3.15") else 6 + 10 if parse_version(REST_FRAMEWORK_VERSION) >= parse_version("3.15") else 9 ) with self.assertNumQueries(expected_queries): r = self.client.post(path, data, content_type="application/json") @@ -97,7 +97,7 @@ def test_ca_post_with_date_none_api(self): "validity_start": None, "validity_end": None, } - with self.assertNumQueries(4): + with self.assertNumQueries(7): r = self.client.post(path, data, content_type="application/json") self.assertEqual(r.status_code, 201) self.assertEqual(Ca.objects.count(), 1) @@ -124,7 +124,7 @@ def test_ca_put_api(self): path = reverse("pki_api:ca_detail", args=[ca1.pk]) org2 = self._create_org() data = {"name": "change-ca1", "organization": org2.pk, "notes": "change-notes"} - with self.assertNumQueries(8): + with self.assertNumQueries(11): r = self.client.put(path, data, content_type="application/json") self.assertEqual(r.status_code, 200) self.assertEqual(r.data["name"], "change-ca1") @@ -137,7 +137,7 @@ def test_ca_patch_api(self): data = { "name": "change-ca1", } - with self.assertNumQueries(7): + with self.assertNumQueries(10): r = self.client.patch(path, data, content_type="application/json") self.assertEqual(r.status_code, 200) self.assertEqual(r.data["name"], "change-ca1") @@ -161,7 +161,7 @@ def test_ca_post_renew_api(self): ca1 = self._create_ca(name="ca1", organization=self._get_org()) old_serial_num = ca1.serial_number path = reverse("pki_api:ca_renew", args=[ca1.pk]) - with self.assertNumQueries(5): + with self.assertNumQueries(8): r = self.client.post(path) ca1.refresh_from_db() self.assertEqual(r.status_code, 200) @@ -172,7 +172,7 @@ def test_cert_post_api(self): path = reverse("pki_api:cert_list") data = self._cert_data data["ca"] = self._create_ca().pk - with self.assertNumQueries(8): + with self.assertNumQueries(11): r = self.client.post(path, data, content_type="application/json") self.assertEqual(r.status_code, 201) self.assertEqual(Cert.objects.count(), 1) @@ -189,7 +189,7 @@ def test_import_cert_post_api(self): "private_key": ca1.private_key, } expected_queries = ( - 11 if parse_version(REST_FRAMEWORK_VERSION) >= parse_version("3.15") else 10 + 14 if parse_version(REST_FRAMEWORK_VERSION) >= parse_version("3.15") else 13 ) with self.assertNumQueries(expected_queries): r = self.client.post(path, data, content_type="application/json") @@ -205,7 +205,7 @@ def test_cert_post_with_extensions_field(self): data = self._cert_data data["ca"] = self._create_ca().pk data["extensions"] = [] - with self.assertNumQueries(8): + with self.assertNumQueries(11): r = self.client.post(path, data, content_type="application/json") self.assertEqual(r.status_code, 201) self.assertEqual(Cert.objects.count(), 1) @@ -221,7 +221,7 @@ def test_cert_post_with_date_none(self): "validity_start": None, "validity_end": None, } - with self.assertNumQueries(8): + with self.assertNumQueries(11): r = self.client.post(path, data, content_type="application/json") self.assertEqual(r.status_code, 201) self.assertEqual(Cert.objects.count(), 1) @@ -253,7 +253,7 @@ def test_cert_put_api(self): "organization": org2.pk, "notes": "new-notes", } - with self.assertNumQueries(10): + with self.assertNumQueries(13): r = self.client.put(path, data, content_type="application/json") self.assertEqual(r.status_code, 200) self.assertEqual(r.data["name"], "cert1-change") @@ -264,7 +264,7 @@ def test_cert_patch_api(self): cert1 = self._create_cert(name="cert1") path = reverse("pki_api:cert_detail", args=[cert1.pk]) data = {"name": "cert1-change"} - with self.assertNumQueries(9): + with self.assertNumQueries(12): r = self.client.patch(path, data, content_type="application/json") self.assertEqual(r.status_code, 200) self.assertEqual(r.data["name"], "cert1-change") @@ -289,7 +289,7 @@ def test_post_cert_renew_api(self): cert1 = self._create_cert(name="cert1") old_serial_num = cert1.serial_number path = reverse("pki_api:cert_renew", args=[cert1.pk]) - with self.assertNumQueries(6): + with self.assertNumQueries(9): r = self.client.post(path) self.assertEqual(r.status_code, 200) cert1.refresh_from_db() @@ -300,7 +300,7 @@ def test_post_cert_revoke_api(self): cert1 = self._create_cert(name="cert1") self.assertFalse(cert1.revoked) path = reverse("pki_api:cert_revoke", args=[cert1.pk]) - with self.assertNumQueries(5): + with self.assertNumQueries(8): r = self.client.post(path) cert1.refresh_from_db() self.assertEqual(r.status_code, 200) diff --git a/tests/openwisp2/sample_config/api/views.py b/tests/openwisp2/sample_config/api/views.py index daa8e2feb..e81ce1bbc 100644 --- a/tests/openwisp2/sample_config/api/views.py +++ b/tests/openwisp2/sample_config/api/views.py @@ -28,12 +28,21 @@ from openwisp_controller.config.api.views import ( DeviceListCreateView as BaseDeviceListCreateView, ) +from openwisp_controller.config.api.views import ( + RevisionListView as BaseRevisionListView, +) +from openwisp_controller.config.api.views import ( + RevisionRestoreView as BaseRevisionRestoreView, +) from openwisp_controller.config.api.views import ( TemplateDetailView as BaseTemplateDetailView, ) from openwisp_controller.config.api.views import ( TemplateListCreateView as BaseTemplateListCreateView, ) +from openwisp_controller.config.api.views import ( + VersionDetailView as BaseVersionDetailView, +) from openwisp_controller.config.api.views import VpnDetailView as BaseVpnDetailView from openwisp_controller.config.api.views import ( VpnListCreateView as BaseVpnListCreateView, @@ -96,6 +105,18 @@ class DownloadDeviceView(BaseDownloadDeviceView): pass +class RevisionListView(BaseRevisionListView): + pass + + +class VersionDetailView(BaseVersionDetailView): + pass + + +class RevisionRestoreView(BaseRevisionRestoreView): + pass + + template_list = TemplateListCreateView.as_view() template_detail = TemplateDetailView.as_view() download_template_config = DownloadTemplateconfiguration.as_view() @@ -110,3 +131,6 @@ class DownloadDeviceView(BaseDownloadDeviceView): devicegroup_list = DeviceGroupListCreateView.as_view() devicegroup_detail = DeviceGroupDetailView.as_view() devicegroup_commonname = DeviceGroupCommonName.as_view() +revision_list = RevisionListView.as_view() +version_detail = VersionDetailView.as_view() +revision_restore = RevisionRestoreView.as_view()