diff --git a/api/environments/serializers.py b/api/environments/serializers.py index 3448b2b85374..e1702c4d3bc4 100644 --- a/api/environments/serializers.py +++ b/api/environments/serializers.py @@ -4,7 +4,7 @@ from environments.models import Environment, EnvironmentAPIKey, Webhook from features.serializers import FeatureStateSerializerFull -from metadata.serializers import MetadataSerializer, MetadataSerializerMixin +from metadata.serializers import MetadataSerializerMixin from organisations.models import Subscription from organisations.subscriptions.serializers.mixins import ( ReadOnlyIfNotValidPlanMixin, @@ -79,8 +79,6 @@ class EnvironmentSerializerWithMetadata( DeleteBeforeUpdateWritableNestedModelSerializer, EnvironmentSerializerLight, ): - metadata = MetadataSerializer(required=False, many=True) - class Meta(EnvironmentSerializerLight.Meta): fields = EnvironmentSerializerLight.Meta.fields + ("metadata",) # type: ignore[assignment] diff --git a/api/features/serializers.py b/api/features/serializers.py index d922dd11ae47..5de6c17a44af 100644 --- a/api/features/serializers.py +++ b/api/features/serializers.py @@ -23,7 +23,7 @@ ) from integrations.github.constants import GitHubEventType from integrations.github.github import call_github_task -from metadata.serializers import MetadataSerializer, MetadataSerializerMixin +from metadata.serializers import MetadataSerializerMixin from projects.code_references.serializers import ( FeatureFlagCodeReferencesRepositoryCountSerializer, ) @@ -345,8 +345,6 @@ def get_last_modified_in_current_environment( class FeatureSerializerWithMetadata(MetadataSerializerMixin, CreateFeatureSerializer): - metadata = MetadataSerializer(required=False, many=True) - code_references_counts = FeatureFlagCodeReferencesRepositoryCountSerializer( many=True, read_only=True, diff --git a/api/import_export/export.py b/api/import_export/export.py index 30cb541680e9..5531dd88b9b5 100644 --- a/api/import_export/export.py +++ b/api/import_export/export.py @@ -8,7 +8,7 @@ import boto3 from django.core import serializers from django.core.serializers.json import DjangoJSONEncoder -from django.db.models import F, Model, Q +from django.db.models import Model, Q from edge_api.identities.export import export_edge_identity_and_overrides from environments.identities.models import Identity @@ -130,28 +130,28 @@ def export_projects( *_export_entities( _EntityExportConfig( Segment, - Q(project__organisation__id=organisation_id, id=F("version_of")), + Q(project__organisation__id=organisation_id, version_of__isnull=True), ), _EntityExportConfig( SegmentRule, Q( segment__project__organisation__id=organisation_id, - segment_id=F("segment__version_of"), + segment__version_of__isnull=True, ) | Q( rule__segment__project__organisation__id=organisation_id, - rule__segment_id=F("rule__segment__version_of"), + rule__segment__version_of__isnull=True, ), ), _EntityExportConfig( Condition, Q( rule__segment__project__organisation__id=organisation_id, - rule__segment_id=F("rule__segment__version_of"), + rule__segment__version_of__isnull=True, ) | Q( rule__rule__segment__project__organisation__id=organisation_id, - rule__rule__segment_id=F("rule__rule__segment__version_of"), + rule__rule__segment__version_of__isnull=True, ), ), _EntityExportConfig(Tag, default_filter), diff --git a/api/metadata/serializers.py b/api/metadata/serializers.py index 8cfca2adb5b4..104d1241e821 100644 --- a/api/metadata/serializers.py +++ b/api/metadata/serializers.py @@ -103,11 +103,15 @@ def validate(self, attrs: dict[str, Any]) -> dict[str, Any]: return attrs -class MetadataSerializerMixin: +class MetadataSerializerMixin(serializers.Serializer): # type: ignore[type-arg] """ - Functionality for serializers that need to handle metadata + Mixin for serializers that need to handle metadata + + NOTE: Child serializers should include 'metadata' in their Meta.fields. """ + metadata = MetadataSerializer(required=False, many=True) + def _validate_required_metadata( self, organisation: Organisation, metadata: list[dict[str, Any]] ) -> None: diff --git a/api/segments/managers.py b/api/segments/managers.py deleted file mode 100644 index d1adddb94771..000000000000 --- a/api/segments/managers.py +++ /dev/null @@ -1,16 +0,0 @@ -from django.db.models import F - -from core.models import SoftDeleteExportableManager - - -class SegmentManager(SoftDeleteExportableManager): - pass - - -class LiveSegmentManager(SoftDeleteExportableManager): - def get_queryset(self): # type: ignore[no-untyped-def] - """ - Returns only the canonical segments, which will always be - the highest version. - """ - return super().get_queryset().filter(id=F("version_of")) diff --git a/api/segments/migrations/0030_add_default_to_segment_version.py b/api/segments/migrations/0030_add_default_to_segment_version.py new file mode 100644 index 000000000000..4813ecd210a5 --- /dev/null +++ b/api/segments/migrations/0030_add_default_to_segment_version.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.22 on 2025-11-11 00:08 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("segments", "0029_add_is_system_segment"), + ] + + operations = [ + migrations.AlterField( + model_name="historicalsegment", + name="version", + field=models.IntegerField(default=1, null=True), + ), + migrations.AlterField( + model_name="segment", + name="version", + field=models.IntegerField(default=1, null=True), + ), + ] diff --git a/api/segments/migrations/0031_set_version_of_to_null_for_canonical_segments.py b/api/segments/migrations/0031_set_version_of_to_null_for_canonical_segments.py new file mode 100644 index 000000000000..5eb0ab713bdb --- /dev/null +++ b/api/segments/migrations/0031_set_version_of_to_null_for_canonical_segments.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.22 on 2025-11-11 03:43 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ("segments", "0030_add_default_to_segment_version"), + ] + + operations = [ + # Set version_of to NULL for canonical segments (where version_of_id = id). + # This follows the same pattern as migration 0023 which originally set + # version_of_id = id. Like that migration, this may block during deployment + # depending on the number of canonical segments in the database. + migrations.RunSQL( + sql="UPDATE segments_segment SET version_of_id = NULL WHERE version_of_id = id;", + reverse_sql="UPDATE segments_segment SET version_of_id = id WHERE version_of_id IS NULL;", + ), + ] diff --git a/api/segments/models.py b/api/segments/models.py index c1892383eb55..34fec845f782 100644 --- a/api/segments/models.py +++ b/api/segments/models.py @@ -8,10 +8,7 @@ from django.core.exceptions import ValidationError from django.db import models, transaction from django_lifecycle import ( # type: ignore[import-untyped] - AFTER_CREATE, - BEFORE_CREATE, LifecycleModelMixin, - hook, ) from flag_engine.segments import constants @@ -30,13 +27,37 @@ from metadata.models import Metadata from projects.models import Project -from .managers import LiveSegmentManager, SegmentManager - ModelT = typing.TypeVar("ModelT", bound=models.Model) logger = logging.getLogger(__name__) +class LiveSegmentManager(SoftDeleteExportableManager): + def get_queryset(self): # type: ignore[no-untyped-def] + """ + Returns only canonical segments (where version_of is NULL). + Canonical segments represent the current/live version. + """ + return super().get_queryset().filter(version_of__isnull=True) + + +class RevisionsManager(SoftDeleteExportableManager): + def get_queryset(self): # type: ignore[no-untyped-def] + """ + Returns only segment revisions (where version_of is NOT NULL). + Revisions are historical versions of segments. + """ + return super().get_queryset().filter(version_of__isnull=False) + + +class AllSegmentsManager(SoftDeleteExportableManager): + """ + Returns all segments (both canonical and revisions). + Only filters out soft-deleted segments. + """ + pass + + class ConfiguredOrderManager(SoftDeleteExportableManager, models.Manager[ModelT]): setting_name: str @@ -87,8 +108,7 @@ class Segment( Feature, on_delete=models.CASCADE, related_name="segments", null=True ) - # This defaults to 1 for newly created segments. - version = models.IntegerField(null=True) + version = models.IntegerField(default=1, null=True) version_of = models.ForeignKey( "self", @@ -112,10 +132,11 @@ class Segment( updated_at = models.DateTimeField(null=True, auto_now=True) is_system_segment = models.BooleanField(default=False) - objects = SegmentManager() # type: ignore[misc] - - # Only serves segments that are the canonical version. - live_objects = LiveSegmentManager() + # Manager declarations - order matters! First manager is the base_manager used in relations. + objects = LiveSegmentManager() # type: ignore[misc] # Default: canonical segments only + live_objects = objects # Explicit alias for clarity + revisions = RevisionsManager() # type: ignore[misc] # Only historical versions + all_objects = AllSegmentsManager() # type: ignore[misc] # Both canonical and revisions class Meta: ordering = ("id",) # explicit ordering to prevent pagination warnings @@ -126,27 +147,8 @@ def __str__(self): # type: ignore[no-untyped-def] def get_skip_create_audit_log(self) -> bool: if self.is_system_segment: return True - try: - if self.version_of_id and self.version_of_id != self.id: - return True - except Segment.DoesNotExist: - return True - - return False - - @hook(BEFORE_CREATE, when="version_of", is_now=None) - def set_default_version_to_one_if_new_segment(self): # type: ignore[no-untyped-def] - if self.version is None: - self.version = 1 - - @hook(AFTER_CREATE, when="version_of", is_now=None) - def set_version_of_to_self_if_none(self): # type: ignore[no-untyped-def] - """ - This allows the segment model to reference all versions of - itself including itself. - """ - self.version_of = self - self.save_without_historical_record() + is_revision = self.version_of_id is not None + return is_revision @transaction.atomic def clone(self, is_revision: bool = False, **extra_attrs: typing.Any) -> "Segment": @@ -165,7 +167,7 @@ def clone(self, is_revision: bool = False, **extra_attrs: typing.Any) -> "Segmen cloned_segment.copy_rules_and_conditions_from(self) # Handle versioning - version_of = self if is_revision else cloned_segment + version_of = self if is_revision else None cloned_segment.version_of = extra_attrs.get("version_of", version_of) cloned_segment.version = self.version if is_revision else 1 Segment.objects.filter(pk=cloned_segment.pk).update( diff --git a/api/segments/serializers.py b/api/segments/serializers.py index 4cbc203894f4..049961f17f3e 100644 --- a/api/segments/serializers.py +++ b/api/segments/serializers.py @@ -7,7 +7,7 @@ from rest_framework import serializers from rest_framework.exceptions import ValidationError -from metadata.serializers import MetadataSerializer, MetadataSerializerMixin +from metadata.serializers import MetadataSerializerMixin from projects.models import Project from segments.models import Condition, Segment, SegmentRule @@ -80,7 +80,6 @@ class Meta: class SegmentSerializer(MetadataSerializerMixin, WritableNestedModelSerializer): rules = SegmentRuleSerializer(many=True, required=True, allow_empty=False) - metadata = MetadataSerializer(required=False, many=True) def __init__(self, *args: Any, **kwargs: Any) -> None: """ diff --git a/api/tests/unit/segments/test_unit_segments_migrations.py b/api/tests/unit/segments/test_unit_segments_migrations.py index 78ddbe253fe7..75650b626b9b 100644 --- a/api/tests/unit/segments/test_unit_segments_migrations.py +++ b/api/tests/unit/segments/test_unit_segments_migrations.py @@ -155,7 +155,7 @@ def test_add_versioning_to_segments_forwards(migrator: Migrator) -> None: # Then the version_of attribute is correctly set. NewSegment = new_state.apps.get_model("segments", "Segment") new_segment = NewSegment.objects.get(id=segment.id) - assert new_segment.version_of == new_segment + assert new_segment.version_of_id == new_segment.id @pytest.mark.skipif( diff --git a/api/tests/unit/segments/test_unit_segments_models.py b/api/tests/unit/segments/test_unit_segments_models.py index bb689d0239ff..02c221e72f03 100644 --- a/api/tests/unit/segments/test_unit_segments_models.py +++ b/api/tests/unit/segments/test_unit_segments_models.py @@ -1,6 +1,5 @@ from collections.abc import Callable from typing import Any -from unittest.mock import PropertyMock import pytest from django.core.exceptions import ValidationError @@ -29,7 +28,15 @@ def test_Condition_str__returns_readable_representation_of_condition( assert result == "Condition for ALL rule for Segment - segment: foo EQUAL bar" -def test_Condition_get_skip_create_audit_log__returns_true( +@pytest.mark.parametrize( + "delete", + [ + lambda rule: rule.delete(), + lambda rule: rule.hard_delete(), + ], +) +def test_Condition_get_skip_create_audit_log__rule_deleted__returns_true( + delete: Callable[[SegmentRule], None], segment_rule: SegmentRule, ) -> None: # Given @@ -38,16 +45,45 @@ def test_Condition_get_skip_create_audit_log__returns_true( property="foo", operator=EQUAL, value="bar", + created_with_segment=False, ) # When - result = condition.get_skip_create_audit_log() + delete(segment_rule) # Then - assert result is True + assert condition.get_skip_create_audit_log() is True -def test_manager_returns_only_highest_version_of_segments( +@pytest.mark.parametrize( + "delete", + [ + lambda segment: segment.delete(), + lambda segment: segment.hard_delete(), + ], +) +def test_Condition_get_skip_create_audit_log__segment_deleted__returns_true( + delete: Callable[[Segment], None], + segment: Segment, + segment_rule: SegmentRule, +) -> None: + # Given + condition = Condition.objects.create( + rule=segment_rule, + property="foo", + operator=EQUAL, + value="bar", + created_with_segment=False, + ) + + # When + delete(segment) + + # Then + assert condition.get_skip_create_audit_log() is True + + +def test_LiveSegmentManager__returns_only_highest_version_of_segments( segment: Segment, ) -> None: # Given @@ -108,24 +144,7 @@ def test_SegmentRule_get_skip_create_audit_log__returns_true( assert result is True -def test_segment_get_skip_create_audit_log_when_exception( - mocker: MockerFixture, - segment: Segment, -) -> None: - # Given - patched_segment = mocker.patch.object( - Segment, "version_of_id", new_callable=PropertyMock - ) - patched_segment.side_effect = Segment.DoesNotExist("Segment missing") - - # When - result = segment.get_skip_create_audit_log() - - # Then - assert result is True - - -def test_delete_segment_only_schedules_one_task_for_audit_log_creation( +def test_Segment_delete__multiple_rules_conditions__schedules_audit_log_task_once( mocker: MockerFixture, segment: Segment ) -> None: # Given @@ -143,11 +162,11 @@ def test_delete_segment_only_schedules_one_task_for_audit_log_creation( ) # When - mocked_tasks = mocker.patch("core.signals.tasks") + task = mocker.patch("core.signals.tasks.create_audit_log_from_historical_record") segment.delete() # Then - assert len(mocked_tasks.mock_calls) == 1 + assert task.delay.call_count == 1 def test_Segment_clone__can_create_standalone_segment_clone( @@ -163,7 +182,7 @@ def test_Segment_clone__can_create_standalone_segment_clone( # Then assert cloned_segment != segment assert cloned_segment.name == "another-segment" - assert cloned_segment.version_of == cloned_segment + assert cloned_segment.version_of is None assert cloned_segment.version == 1 @@ -264,7 +283,9 @@ def test_Segment_clone__segment_with_rules__returns_new_segment_with_copied_rule ] -def test_system_segment_get_skip_create_audit_log(system_segment: Segment) -> None: +def test_Segment_get_skip_create_audit_log__system_segment__returns_true( + system_segment: Segment, +) -> None: # When result = system_segment.get_skip_create_audit_log() diff --git a/api/tests/unit/segments/test_unit_segments_views.py b/api/tests/unit/segments/test_unit_segments_views.py index 57cc212884ed..3520902efb6a 100644 --- a/api/tests/unit/segments/test_unit_segments_views.py +++ b/api/tests/unit/segments/test_unit_segments_views.py @@ -207,7 +207,7 @@ def test_segments_limit_ignores_old_segment_versions( # and create some older versions for the segment fixture segment.clone(is_revision=True) - assert Segment.objects.filter(version_of=segment).count() == 2 + assert Segment.revisions.filter(version_of=segment).count() == 1 assert Segment.live_objects.count() == 1 url = reverse("api-v1:projects:project-segments-list", args=[project.id]) @@ -1058,11 +1058,11 @@ def test_update_segment_versioned_segment( assert response.status_code == status.HTTP_200_OK # Now verify that a new versioned segment has been set. - assert Segment.objects.filter(version_of=segment).count() == 2 + assert Segment.revisions.filter(version_of=segment).count() == 1 # Now check the previously versioned segment to match former count of conditions. - versioned_segment = Segment.objects.filter(version_of=segment, version=1).first() + versioned_segment = Segment.revisions.filter(version_of=segment, version=1).first() assert versioned_segment != segment assert versioned_segment.rules.count() == 1 versioned_rule = versioned_segment.rules.first() @@ -1093,7 +1093,8 @@ def test_update_segment_versioned_segment_with_thrown_exception( rule=nested_rule, property="foo", operator=EQUAL, value="bar" ) - assert segment.version == 1 == Segment.objects.filter(version_of=segment).count() + assert segment.version == 1 + assert Segment.revisions.filter(version_of=segment).count() == 0 new_condition_property = "foo2" new_condition_value = "bar" @@ -1144,7 +1145,9 @@ def test_update_segment_versioned_segment_with_thrown_exception( segment.refresh_from_db() # Now verify that the version of the segment has not been changed. - assert segment.version == 1 == Segment.objects.filter(version_of=segment).count() + # The transaction should have rolled back, so no revisions should exist. + assert segment.version == 1 + assert Segment.revisions.filter(version_of=segment).count() == 0 @pytest.mark.parametrize( diff --git a/api/util/mappers/engine.py b/api/util/mappers/engine.py index a8ced81fbdb3..b2ddea2aeff3 100644 --- a/api/util/mappers/engine.py +++ b/api/util/mappers/engine.py @@ -201,10 +201,9 @@ def map_environment_to_engine( organisation: "Organisation" = project.organisation # Read relationships - grab all the data needed from the ORM here. - - project_segments = [ - ps for ps in project.segments.all() if ps.id == ps.version_of_id - ] + # Note: project.segments uses Segment's base manager (LiveSegmentManager), + # which returns only canonical segments. + project_segments = list(project.segments.all()) project_segment_rules_by_segment_id: Dict[ int,