Skip to content

Commit ad47864

Browse files
authored
add merge rollback logic (#4818)
* add merge rollback logic * new integration test for rollback * format
1 parent 353abd9 commit ad47864

File tree

9 files changed

+330
-27
lines changed

9 files changed

+330
-27
lines changed

backend/infrahub/core/branch/tasks.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from infrahub.core.validators.tasks import schema_validate_migrations
1818
from infrahub.dependencies.registry import get_component_registry
1919
from infrahub.events.branch_action import BranchDeleteEvent
20-
from infrahub.exceptions import ValidationError
20+
from infrahub.exceptions import MergeFailedError, ValidationError
2121
from infrahub.log import get_log_data
2222
from infrahub.message_bus import Meta, messages
2323
from infrahub.services import services
@@ -138,7 +138,7 @@ async def rebase_branch(branch: str) -> None:
138138

139139

140140
@flow(name="branch-merge")
141-
async def merge_branch(branch: str, conflict_resolution: dict[str, bool] | None = None) -> None:
141+
async def merge_branch(branch: str) -> None:
142142
service = services.service
143143
log = get_run_logger()
144144

@@ -150,14 +150,21 @@ async def merge_branch(branch: str, conflict_resolution: dict[str, bool] | None
150150

151151
merger: BranchMerger | None = None
152152
async with lock.registry.global_graph_lock():
153-
async with service.database.start_transaction() as db:
154-
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj)
155-
diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj)
156-
merger = BranchMerger(
157-
db=db, diff_coordinator=diff_coordinator, diff_merger=diff_merger, source_branch=obj, service=service
158-
)
159-
await merger.merge(conflict_resolution=conflict_resolution)
160-
await merger.update_schema()
153+
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=service.database, branch=obj)
154+
diff_merger = await component_registry.get_component(DiffMerger, db=service.database, branch=obj)
155+
merger = BranchMerger(
156+
db=service.database,
157+
diff_coordinator=diff_coordinator,
158+
diff_merger=diff_merger,
159+
source_branch=obj,
160+
service=service,
161+
)
162+
try:
163+
await merger.merge()
164+
except Exception as exc:
165+
await merger.rollback()
166+
raise MergeFailedError(branch_name=branch) from exc
167+
await merger.update_schema()
161168

162169
if merger and merger.migrations:
163170
errors = await schema_apply_migrations(

backend/infrahub/core/diff/merger/merger.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from infrahub.core import registry
66
from infrahub.core.diff.model.path import BranchTrackingId
7-
from infrahub.core.diff.query.merge import DiffMergePropertiesQuery, DiffMergeQuery
7+
from infrahub.core.diff.query.merge import DiffMergePropertiesQuery, DiffMergeQuery, DiffMergeRollbackQuery
88

99
if TYPE_CHECKING:
1010
from infrahub.core.branch import Branch
@@ -65,3 +65,9 @@ async def merge_graph(self, at: Timestamp) -> None:
6565
self.source_branch.branched_from = at.to_string()
6666
await self.source_branch.save(db=self.db)
6767
registry.branch[self.source_branch.name] = self.source_branch
68+
69+
async def rollback(self, at: Timestamp) -> None:
70+
rollback_query = await DiffMergeRollbackQuery.init(
71+
db=self.db, branch=self.source_branch, target_branch=self.destination_branch, at=at
72+
)
73+
await rollback_query.execute(db=self.db)

backend/infrahub/core/diff/query/merge.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,17 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
166166
CALL {
167167
WITH n, rel_name, rel_peer_id, related_rel_status
168168
MATCH (n)
169-
-[source_r_rel_1:IS_RELATED {branch: $source_branch}]
169+
-[source_r_rel_1:IS_RELATED]
170170
-(r:Relationship {name: rel_name})
171-
-[source_r_rel_2:IS_RELATED {branch: $source_branch}]
171+
-[source_r_rel_2:IS_RELATED]
172172
-(:Node {uuid: rel_peer_id})
173-
WHERE source_r_rel_1.from <= $at AND source_r_rel_1.to IS NULL
173+
WHERE source_r_rel_1.branch IN [$source_branch, $target_branch]
174+
AND source_r_rel_2.branch IN [$source_branch, $target_branch]
175+
AND source_r_rel_1.from <= $at AND source_r_rel_1.to IS NULL
174176
AND source_r_rel_2.from <= $at AND source_r_rel_2.to IS NULL
177+
WITH n, rel_name, rel_peer_id, related_rel_status, r, source_r_rel_1, source_r_rel_2
178+
ORDER BY source_r_rel_1.branch_level DESC, source_r_rel_2.branch_level DESC, source_r_rel_1.from DESC, source_r_rel_2.from DESC
179+
LIMIT 1
175180
RETURN r, CASE
176181
WHEN startNode(source_r_rel_1).uuid = n.uuid THEN "r"
177182
ELSE "l"
@@ -433,3 +438,44 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
433438
}
434439
"""
435440
self.add_to_query(query=query)
441+
442+
443+
class DiffMergeRollbackQuery(Query):
444+
name = "diff_merge_rollback"
445+
type = QueryType.WRITE
446+
insert_return = False
447+
448+
def __init__(
449+
self,
450+
at: Timestamp,
451+
target_branch: Branch,
452+
**kwargs: Any,
453+
) -> None:
454+
super().__init__(**kwargs)
455+
self.at = at
456+
self.target_branch = target_branch
457+
self.source_branch_name = self.branch.name
458+
459+
async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
460+
self.params = {
461+
"at": self.at.to_string(),
462+
"target_branch": self.target_branch.name,
463+
"source_branch": self.source_branch_name,
464+
}
465+
query = """
466+
// ---------------------------
467+
// reset to times on target branch
468+
// ---------------------------
469+
CALL {
470+
OPTIONAL MATCH ()-[r_to {to: $at, branch: $target_branch}]-()
471+
SET r_to.to = NULL
472+
}
473+
// ---------------------------
474+
// reset from times on target branch
475+
// ---------------------------
476+
CALL {
477+
OPTIONAL MATCH ()-[r_from {from: $at, branch: $target_branch}]-()
478+
DELETE r_from
479+
}
480+
"""
481+
self.add_to_query(query=query)

backend/infrahub/core/merge.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def __init__(
4444
self.diff_coordinator = diff_coordinator
4545
self.diff_merger = diff_merger
4646
self.migrations: list[SchemaUpdateMigrationInfo] = []
47+
self._merge_at = Timestamp()
4748
self._graph_diff: Optional[BranchDiffer] = None
4849

4950
self._source_schema: Optional[SchemaBranch] = None
@@ -230,7 +231,6 @@ async def validate_graph(self) -> list[DataConflict]:
230231
async def merge(
231232
self,
232233
at: Optional[Union[str, Timestamp]] = None,
233-
conflict_resolution: Optional[dict[str, bool]] = None, # pylint: disable=unused-argument
234234
) -> None:
235235
"""Merge the current branch into main."""
236236
if self.source_branch.name == registry.default_branch:
@@ -252,10 +252,13 @@ async def merge(
252252

253253
# TODO need to find a way to properly communicate back to the user any issue that could come up during the merge
254254
# From the Graph or From the repositories
255-
at = Timestamp(at)
256-
await self.diff_merger.merge_graph(at=at)
255+
self._merge_at = Timestamp(at)
256+
await self.diff_merger.merge_graph(at=self._merge_at)
257257
await self.merge_repositories()
258258

259+
async def rollback(self) -> None:
260+
await self.diff_merger.rollback(at=self._merge_at)
261+
259262
async def merge_repositories(self) -> None:
260263
# Collect all Repositories in Main because we'll need the commit in Main for each one.
261264
repos_in_main_list = await NodeManager.query(schema=CoreRepository, db=self.db)

backend/infrahub/exceptions.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,11 @@ class HTTPServerTimeoutError(HTTPServerError):
317317

318318
class HTTPServerSSLError(HTTPServerError):
319319
HTTP_CODE = 503
320+
321+
322+
class MergeFailedError(Error):
323+
HTTP_CODE: int = 500
324+
325+
def __init__(self, branch_name: str) -> None:
326+
self.message = f"Failed to merge branch '{branch_name}'"
327+
super().__init__(self.message)

backend/infrahub/graphql/mutations/proposed_change.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ async def mutate_update( # pylint: disable=too-many-branches
140140
)
141141

142142
if updated_state == ProposedChangeState.MERGED:
143-
conflict_resolution: dict[str, bool] = {}
144143
source_branch = await Branch.get_by_name(db=dbt, name=proposed_change.source_branch.value)
145144
validations = await proposed_change.validations.get_peers(db=dbt)
146145
for validation in validations.values():
@@ -158,13 +157,10 @@ async def mutate_update( # pylint: disable=too-many-branches
158157
raise ValidationError(
159158
"Data conflicts found on branch and missing decisions about what branch to keep"
160159
)
161-
if check.conflicts.value:
162-
keep_source_value = check.keep_branch.value.value == "source"
163-
conflict_resolution[check.conflicts.value[0]["path"]] = keep_source_value
164160

165161
await context.service.workflow.execute_workflow(
166162
workflow=BRANCH_MERGE,
167-
parameters={"branch": source_branch.name, "conflict_resolution": conflict_resolution},
163+
parameters={"branch": source_branch.name},
168164
)
169165

170166
return proposed_change, result
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING
4+
from unittest.mock import patch
5+
6+
import pytest
7+
from infrahub_sdk.exceptions import GraphQLError
8+
9+
from infrahub.core.initialization import create_branch
10+
from infrahub.core.manager import NodeManager
11+
from infrahub.core.merge import BranchMerger
12+
from infrahub.core.node import Node
13+
from infrahub.services.adapters.cache.redis import RedisCache
14+
from tests.constants import TestKind
15+
from tests.helpers.schema import CAR_SCHEMA, load_schema
16+
from tests.helpers.test_app import TestInfrahubApp
17+
18+
if TYPE_CHECKING:
19+
from infrahub_sdk import InfrahubClient
20+
21+
from infrahub.core.branch.models import Branch
22+
from infrahub.database import InfrahubDatabase
23+
from tests.adapters.message_bus import BusSimulator
24+
25+
26+
BRANCH_MERGE = """
27+
mutation($branch: String!) {
28+
BranchMerge(data: { name: $branch }) {
29+
ok
30+
}
31+
}
32+
"""
33+
34+
35+
class BrokenBranchMerger:
36+
def __init__(self, *args, **kwargs) -> None:
37+
self.real_merger = BranchMerger(*args, **kwargs)
38+
39+
async def merge(self, at=None) -> None:
40+
await self.real_merger.merge(at=at)
41+
raise ValueError("This is broken on purpose")
42+
43+
async def rollback(self) -> None:
44+
await self.real_merger.rollback()
45+
46+
47+
class TestBranchMergeRollback(TestInfrahubApp):
48+
@pytest.fixture(scope="class")
49+
async def initial_dataset(
50+
self,
51+
db: InfrahubDatabase,
52+
initialize_registry: None,
53+
bus_simulator: BusSimulator,
54+
prefect_test_fixture: None,
55+
) -> dict[str, Node]:
56+
await load_schema(db, schema=CAR_SCHEMA)
57+
58+
bus_simulator.service.cache = RedisCache()
59+
60+
john = await Node.init(schema=TestKind.PERSON, db=db)
61+
await john.new(db=db, name="John", height=175, description="The famous Joe Doe")
62+
await john.save(db=db)
63+
kara = await Node.init(schema=TestKind.PERSON, db=db)
64+
await kara.new(db=db, name="Kara Thrace", height=165, description="Starbuck")
65+
await kara.save(db=db)
66+
murphy = await Node.init(schema=TestKind.PERSON, db=db)
67+
await murphy.new(db=db, name="Alex Murphy", height=185, description="Robocop")
68+
await murphy.save(db=db)
69+
omnicorp = await Node.init(schema=TestKind.MANUFACTURER, db=db)
70+
await omnicorp.new(db=db, name="Omnicorp", customers=[murphy])
71+
await omnicorp.save(db=db)
72+
cyberdyne = await Node.init(schema=TestKind.MANUFACTURER, db=db)
73+
await cyberdyne.new(db=db, name="Cyberdyne")
74+
await cyberdyne.save(db=db)
75+
76+
t_800 = await Node.init(schema=TestKind.CAR, db=db)
77+
await t_800.new(
78+
db=db,
79+
name="Cyberdyne systems model 101",
80+
color="Chrome",
81+
description="killing machine with secret heart of gold",
82+
owner=john,
83+
manufacturer=cyberdyne,
84+
)
85+
await t_800.save(db=db)
86+
ed_209 = await Node.init(schema=TestKind.CAR, db=db)
87+
await ed_209.new(
88+
db=db,
89+
name="ED-209",
90+
color="Chrome",
91+
description="still working on doing stairs",
92+
owner=murphy,
93+
manufacturer=omnicorp,
94+
)
95+
await ed_209.save(db=db)
96+
97+
return {
98+
"john": john,
99+
"kara": kara,
100+
"murphy": murphy,
101+
"omnicorp": omnicorp,
102+
"cyberdyne": cyberdyne,
103+
"t_800": t_800,
104+
"ed_209": ed_209,
105+
}
106+
107+
@pytest.fixture(scope="class")
108+
async def branch1(self, db: InfrahubDatabase) -> Branch:
109+
return await create_branch(db=db, branch_name="branch1")
110+
111+
@pytest.fixture(scope="class")
112+
async def branch1_data(
113+
self, db: InfrahubDatabase, initial_dataset: dict[str, Node], branch1: Branch
114+
) -> dict[str, Node]:
115+
kara_branch = await NodeManager.get_one(db=db, branch=branch1, id=initial_dataset["kara"].id)
116+
await kara_branch.delete(db=db)
117+
118+
sarah = await Node.init(schema=TestKind.PERSON, db=db, branch=branch1)
119+
await sarah.new(db=db, name="Sarah", height=161, description="no fate")
120+
await sarah.save(db=db)
121+
122+
t_800_branch = await NodeManager.get_one(db=db, branch=branch1, id=initial_dataset["t_800"].id)
123+
await t_800_branch.owner.update(db=db, data=sarah)
124+
await t_800_branch.save(db=db)
125+
126+
ocp_branch = await NodeManager.get_one(db=db, branch=branch1, id=initial_dataset["omnicorp"].id)
127+
ocp_branch.name.value = "Omni Consumer Products"
128+
await ocp_branch.save(db=db)
129+
130+
return {"sarah": sarah}
131+
132+
async def test_merge_branch_rollback(
133+
self,
134+
db: InfrahubDatabase,
135+
client: InfrahubClient,
136+
initial_dataset: dict[str, Node],
137+
branch1: Branch,
138+
branch1_data: dict[str, Node],
139+
) -> None:
140+
with patch("infrahub.core.branch.tasks.BranchMerger", new=BrokenBranchMerger):
141+
with pytest.raises(GraphQLError) as exc:
142+
await client.execute_graphql(query=BRANCH_MERGE, variables={"branch": branch1.name})
143+
144+
assert exc
145+
assert f"Failed to merge branch '{branch1.name}'" in exc.value.message
146+
147+
# check that the changes on the branch have all been rolled back
148+
kara_main = await NodeManager.get_one(db=db, id=initial_dataset["kara"].id)
149+
assert kara_main.id
150+
151+
sarah = await NodeManager.get_one(db=db, id=branch1_data["sarah"].id)
152+
assert sarah is None
153+
154+
t_800_main = await NodeManager.get_one(db=db, id=initial_dataset["t_800"].id)
155+
owner_peer = await t_800_main.owner.get_peer(db=db)
156+
assert owner_peer.id == initial_dataset["john"].id
157+
158+
ocp_main = await NodeManager.get_one(db=db, id=initial_dataset["omnicorp"].id)
159+
assert ocp_main.name.value == "Omnicorp"

0 commit comments

Comments
 (0)