|
| 1 | +from typing import Literal |
| 2 | +from unittest.mock import AsyncMock |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
1 | 6 | from infrahub.core.branch import Branch |
2 | 7 | from infrahub.core.changelog.diff import DiffChangelogCollector |
3 | 8 | from infrahub.core.changelog.models import RelationshipCardinalityManyChangelog, RelationshipCardinalityOneChangelog |
4 | 9 | from infrahub.core.constants import DiffAction |
5 | 10 | from infrahub.core.diff.coordinator import DiffCoordinator |
| 11 | +from infrahub.core.diff.data_check_synchronizer import DiffDataCheckSynchronizer |
6 | 12 | from infrahub.core.diff.merger.merger import DiffMerger |
| 13 | +from infrahub.core.diff.model.path import ConflictSelection |
| 14 | +from infrahub.core.diff.repository.repository import DiffRepository |
7 | 15 | from infrahub.core.initialization import create_branch |
8 | 16 | from infrahub.core.manager import NodeManager |
9 | 17 | from infrahub.core.node import Node |
@@ -133,3 +141,74 @@ async def test_merge_diff_changelogs(db: InfrahubDatabase, default_branch, car_p |
133 | 141 | assert c2_changelog.relationships["owner"].properties["owner"].value_previous is None |
134 | 142 | assert c2_changelog.relationships["owner"].properties["source"].value == p1.id |
135 | 143 | assert c2_changelog.relationships["owner"].properties["source"].value_previous is None |
| 144 | + |
| 145 | + |
| 146 | +class TestConflict: |
| 147 | + async def _get_diff_coordinator(self, db: InfrahubDatabase, branch: Branch) -> DiffCoordinator: |
| 148 | + component_registry = get_component_registry() |
| 149 | + diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=branch) |
| 150 | + diff_coordinator.data_check_synchronizer = AsyncMock(spec=DiffDataCheckSynchronizer) |
| 151 | + return diff_coordinator |
| 152 | + |
| 153 | + async def _get_diff_merger(self, db: InfrahubDatabase, branch: Branch) -> DiffMerger: |
| 154 | + component_registry = get_component_registry() |
| 155 | + return await component_registry.get_component(DiffMerger, db=db, branch=branch) |
| 156 | + |
| 157 | + @pytest.fixture |
| 158 | + async def diff_repository(self, db: InfrahubDatabase, default_branch: Branch) -> DiffRepository: |
| 159 | + component_registry = get_component_registry() |
| 160 | + return await component_registry.get_component(DiffRepository, db=db, branch=default_branch) |
| 161 | + |
| 162 | + @pytest.mark.parametrize( |
| 163 | + "conflict_selection,expected_value", |
| 164 | + [(ConflictSelection.BASE_BRANCH, "John-main"), (ConflictSelection.DIFF_BRANCH, "John-branch")], |
| 165 | + ) |
| 166 | + async def test_diff_and_merge_with_attribute_value_conflict( |
| 167 | + self, |
| 168 | + db: InfrahubDatabase, |
| 169 | + default_branch: Branch, |
| 170 | + diff_repository: DiffRepository, |
| 171 | + person_john_main: Node, |
| 172 | + person_jane_main: Node, |
| 173 | + person_alfred_main: Node, |
| 174 | + car_accord_main: Node, |
| 175 | + conflict_selection: ConflictSelection, |
| 176 | + expected_value: Literal["John-main", "John-branch"], |
| 177 | + ): |
| 178 | + branch2 = await create_branch(db=db, branch_name="branch2") |
| 179 | + john_main = await NodeManager.get_one(db=db, id=person_john_main.id) |
| 180 | + john_main.name.value = "John-main" |
| 181 | + await john_main.save(db=db) |
| 182 | + john_branch = await NodeManager.get_one(db=db, branch=branch2, id=person_john_main.id) |
| 183 | + john_branch.name.value = "John-branch" |
| 184 | + await john_branch.save(db=db) |
| 185 | + |
| 186 | + at = Timestamp() |
| 187 | + diff_coordinator = await self._get_diff_coordinator(db=db, branch=branch2) |
| 188 | + enriched_diff_metadata = await diff_coordinator.update_branch_diff( |
| 189 | + base_branch=default_branch, diff_branch=branch2 |
| 190 | + ) |
| 191 | + enriched_diff = await diff_repository.get_one( |
| 192 | + diff_branch_name=enriched_diff_metadata.diff_branch_name, diff_id=enriched_diff_metadata.uuid |
| 193 | + ) |
| 194 | + conflicts_map = enriched_diff.get_all_conflicts() |
| 195 | + assert len(conflicts_map) == 1 |
| 196 | + conflict = next(iter(conflicts_map.values())) |
| 197 | + await diff_repository.update_conflict_by_id(conflict_id=conflict.uuid, selection=conflict_selection) |
| 198 | + diff_merger = await self._get_diff_merger(db=db, branch=branch2) |
| 199 | + diff = await diff_merger.merge_graph(at=at) |
| 200 | + diff_events = DiffChangelogCollector(diff=diff, db=db, branch=branch2) |
| 201 | + events = diff_events.collect_changelogs() |
| 202 | + |
| 203 | + match conflict_selection: |
| 204 | + case ConflictSelection.BASE_BRANCH: |
| 205 | + # When we want to keep the conflict in the base branch we don't expect to see any updates after the merge |
| 206 | + assert len(events) == 0 |
| 207 | + case ConflictSelection.DIFF_BRANCH: |
| 208 | + # Expect to see changes on the diff branch when we keep changes from that branch |
| 209 | + assert len(events) == 1 |
| 210 | + event = events[0] |
| 211 | + action, node_changelog = event |
| 212 | + assert action == DiffAction.UPDATED |
| 213 | + assert node_changelog.attributes["name"].value == "John-branch" |
| 214 | + assert node_changelog.attributes["name"].value_previous == "John" |
0 commit comments