Skip to content

Commit 4b1598d

Browse files
authored
Merge pull request #5865 from opsmill/pog-rebase-branch-event-IFC-1046
Send node mutation events with rebase
2 parents 9bcb57c + 382a43b commit 4b1598d

File tree

8 files changed

+210
-25
lines changed

8 files changed

+210
-25
lines changed

backend/infrahub/core/branch/tasks.py

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from typing import Any
4+
from uuid import uuid4
45

56
import pydantic
67
from prefect import flow, get_run_logger
@@ -11,16 +12,17 @@
1112
from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
1213
from infrahub.core import registry
1314
from infrahub.core.branch import Branch
14-
from infrahub.core.changelog.diff import DiffChangelogCollector
15+
from infrahub.core.changelog.diff import DiffChangelogCollector, MigrationTracker
1516
from infrahub.core.constants import MutationAction
1617
from infrahub.core.diff.coordinator import DiffCoordinator
1718
from infrahub.core.diff.ipam_diff_parser import IpamDiffParser
1819
from infrahub.core.diff.merger.merger import DiffMerger
19-
from infrahub.core.diff.model.path import BranchTrackingId
20+
from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot, EnrichedDiffRootMetadata
2021
from infrahub.core.diff.repository.repository import DiffRepository
2122
from infrahub.core.merge import BranchMerger
2223
from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
2324
from infrahub.core.migrations.schema.tasks import schema_apply_migrations
25+
from infrahub.core.timestamp import Timestamp
2426
from infrahub.core.validators.determiner import ConstraintValidatorDeterminer
2527
from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData
2628
from infrahub.core.validators.tasks import schema_validate_migrations
@@ -54,6 +56,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
5456
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
5557
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj)
5658
diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj)
59+
initial_from_time = Timestamp(obj.get_branched_from())
5760
merger = BranchMerger(
5861
db=db,
5962
diff_coordinator=diff_coordinator,
@@ -62,7 +65,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
6265
source_branch=obj,
6366
service=service,
6467
)
65-
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
68+
6669
enriched_diff_metadata = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj)
6770
async for _ in diff_repository.get_all_conflicts_for_diff(
6871
diff_branch_name=enriched_diff_metadata.diff_branch_name, diff_id=enriched_diff_metadata.uuid
@@ -97,7 +100,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
97100
raise ValidationError(",\n".join(error_messages))
98101

99102
schema_in_main_before = merger.destination_schema.duplicate()
100-
103+
migrations = []
101104
async with lock.registry.global_graph_lock():
102105
async with db.start_transaction() as dbt:
103106
await obj.rebase(db=dbt)
@@ -134,6 +137,14 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
134137
for error in errors:
135138
log.error(error)
136139

140+
default_branch_diff = await _get_diff_root(
141+
diff_coordinator=diff_coordinator,
142+
enriched_diff_metadata=enriched_diff_metadata,
143+
diff_repository=diff_repository,
144+
base_branch=base_branch,
145+
target_from=initial_from_time,
146+
)
147+
137148
# -------------------------------------------------------------
138149
# Trigger the reconciliation of IPAM data after the rebase
139150
# -------------------------------------------------------------
@@ -156,14 +167,26 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
156167
# -------------------------------------------------------------
157168
# Generate an event to indicate that a branch has been rebased
158169
# -------------------------------------------------------------
159-
# TODO Add account information
160-
await service.event.send(
161-
event=BranchRebasedEvent(
162-
branch_name=obj.name,
163-
branch_id=str(obj.uuid),
164-
meta=EventMeta.from_context(context=context, branch=registry.get_global_branch()),
165-
)
170+
rebase_event = BranchRebasedEvent(
171+
branch_name=obj.name, branch_id=str(obj.uuid), meta=EventMeta(branch=obj, context=context)
172+
)
173+
events: list[InfrahubEvent] = [rebase_event]
174+
changelog_collector = DiffChangelogCollector(
175+
diff=default_branch_diff, branch=obj, db=db, migration_tracker=MigrationTracker(migrations=migrations)
166176
)
177+
for action, node_changelog in changelog_collector.collect_changelogs():
178+
mutate_event = NodeMutatedEvent(
179+
kind=node_changelog.node_kind,
180+
node_id=node_changelog.node_id,
181+
data=node_changelog,
182+
action=MutationAction.from_diff_action(diff_action=action),
183+
fields=node_changelog.updated_fields,
184+
meta=EventMeta.from_parent(parent=rebase_event, branch=obj),
185+
)
186+
events.append(mutate_event)
187+
188+
for event in events:
189+
await service.event.send(event)
167190

168191

169192
@flow(name="branch-merge", flow_run_name="Merge branch {branch} into main")
@@ -258,7 +281,7 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
258281
events: list[InfrahubEvent] = [merge_event]
259282

260283
for action, node_changelog in node_events:
261-
meta = EventMeta.from_parent(parent=merge_event)
284+
meta = EventMeta.from_parent(parent=merge_event, branch=default_branch)
262285
mutate_event = NodeMutatedEvent(
263286
kind=node_changelog.node_kind,
264287
node_id=node_changelog.node_id,
@@ -267,7 +290,6 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
267290
fields=node_changelog.updated_fields,
268291
meta=meta,
269292
)
270-
mutate_event.set_context_branch(branch=default_branch)
271293
events.append(mutate_event)
272294

273295
for event in events:
@@ -364,3 +386,26 @@ async def create_branch(model: BranchCreateModel, context: InfrahubContext, serv
364386
context=context,
365387
parameters={"branch": obj.name, "branch_id": str(obj.uuid)},
366388
)
389+
390+
391+
async def _get_diff_root(
392+
diff_coordinator: DiffCoordinator,
393+
enriched_diff_metadata: EnrichedDiffRootMetadata,
394+
diff_repository: DiffRepository,
395+
base_branch: Branch,
396+
target_from: Timestamp,
397+
) -> EnrichedDiffRoot:
398+
default_branch_diff = await diff_coordinator.create_or_update_arbitrary_timeframe_diff(
399+
base_branch=base_branch,
400+
diff_branch=base_branch,
401+
from_time=target_from,
402+
to_time=enriched_diff_metadata.to_time,
403+
name=str(uuid4()),
404+
)
405+
# make sure we have the actual diff with data and not just the metadata
406+
if not isinstance(default_branch_diff, EnrichedDiffRoot):
407+
default_branch_diff = await diff_repository.get_one(
408+
diff_branch_name=base_branch.name, diff_id=default_branch_diff.uuid
409+
)
410+
411+
return default_branch_diff

backend/infrahub/core/changelog/diff.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
EnrichedDiffRelationship,
2727
EnrichedDiffRoot,
2828
)
29+
from infrahub.core.models import SchemaUpdateMigrationInfo
2930
from infrahub.core.schema import MainSchemaTypes
3031
from infrahub.database import InfrahubDatabase
3132

@@ -38,11 +39,18 @@ class NodeInDiff:
3839

3940

4041
class DiffChangelogCollector:
41-
def __init__(self, diff: EnrichedDiffRoot, branch: Branch, db: InfrahubDatabase) -> None:
42+
def __init__(
43+
self,
44+
diff: EnrichedDiffRoot,
45+
branch: Branch,
46+
db: InfrahubDatabase,
47+
migration_tracker: MigrationTracker | None = None,
48+
) -> None:
4249
self._diff = diff
4350
self._branch = branch
4451
self._db = db
4552
self._diff_nodes: dict[str, NodeInDiff]
53+
self.migration = migration_tracker or MigrationTracker()
4654

4755
def _populate_diff_nodes(self) -> None:
4856
self._diff_nodes = {
@@ -83,14 +91,16 @@ def _process_node_attribute(
8391
# then we don't have access to the attribute kind
8492
attribute_kind = "n/a"
8593

86-
changelog_attribute = AttributeChangelog(name=attribute.name, kind=attribute_kind)
94+
changelog_attribute = AttributeChangelog(
95+
name=self.migration.get_attribute_name(node=node, attribute=attribute), kind=attribute_kind
96+
)
8797
for attr_property in attribute.properties:
8898
match attr_property.property_type:
8999
case DatabaseEdgeType.HAS_VALUE:
90100
# TODO deserialize correct value type from string
91101
if _keep_branch_update(diff_property=attr_property):
92-
changelog_attribute.value = attr_property.new_value
93-
changelog_attribute.value_previous = attr_property.previous_value
102+
changelog_attribute.set_value(value=attr_property.new_value)
103+
changelog_attribute.set_value_previous(value=attr_property.previous_value)
94104
case DatabaseEdgeType.IS_PROTECTED:
95105
if _keep_branch_update(diff_property=attr_property):
96106
changelog_attribute.add_property(
@@ -243,3 +253,31 @@ def _keep_branch_update(diff_property: EnrichedDiffProperty) -> bool:
243253
if diff_property.conflict and diff_property.conflict.selected_branch == ConflictSelection.BASE_BRANCH:
244254
return False
245255
return True
256+
257+
258+
class MigrationTracker:
259+
"""Keeps track of schema updates that happened as part of a migration"""
260+
261+
def __init__(self, migrations: list[SchemaUpdateMigrationInfo] | None = None) -> None:
262+
# A dictionary of Node kind, previous attribute name and new attribute
263+
# {"TestPerson": {"old_attribute_name": "new_attribute_name"}}
264+
self._migrations_attribute_map: dict[str, dict[str, str]] = {}
265+
266+
migrations = migrations or []
267+
for migration in migrations:
268+
if migration.migration_name == "attribute.name.update":
269+
if migration.path.schema_kind not in self._migrations_attribute_map:
270+
self._migrations_attribute_map[migration.path.schema_kind] = {}
271+
if migration.path.property_name and migration.path.field_name:
272+
self._migrations_attribute_map[migration.path.schema_kind][migration.path.property_name] = (
273+
migration.path.field_name
274+
)
275+
276+
def get_attribute_name(self, node: NodeChangelog, attribute: EnrichedDiffAttribute) -> str:
277+
"""Return the current name of the requested attribute"""
278+
if node.node_kind not in self._migrations_attribute_map:
279+
return attribute.name
280+
if attribute.name not in self._migrations_attribute_map[node.node_kind]:
281+
return attribute.name
282+
283+
return self._migrations_attribute_map[node.node_kind][attribute.name]

backend/infrahub/core/changelog/models.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,18 @@ def has_updates(self) -> bool:
7575
return True
7676
return False
7777

78+
def set_value(self, value: Any) -> None:
79+
if isinstance(value, str) and value == NULL_VALUE:
80+
self.value = None
81+
return
82+
self.value = value
83+
84+
def set_value_previous(self, value: Any) -> None:
85+
if isinstance(value, str) and value == NULL_VALUE:
86+
self.value_previous = None
87+
return
88+
self.value_previous = value
89+
7890
@field_validator("value", "value_previous")
7991
@classmethod
8092
def convert_null_values(cls, value: Any) -> Any:

backend/infrahub/events/models.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,20 +112,25 @@ def with_dummy_context(cls, branch: Branch) -> EventMeta:
112112
)
113113

114114
@classmethod
115-
def from_parent(cls, parent: InfrahubEvent) -> EventMeta:
115+
def from_parent(cls, parent: InfrahubEvent, branch: Branch | None = None) -> EventMeta:
116116
"""Create the metadata from an existing event
117117
118118
Note that this action will modify the existing event to indicate that children might be attached to the event
119119
"""
120120
parent.meta.has_children = True
121+
context = deepcopy(parent.meta.context)
122+
if branch:
123+
context.branch.name = branch.name
124+
context.branch.id = str(branch.get_uuid())
125+
121126
return cls(
122127
parent=parent.meta.id,
123128
branch=parent.meta.branch,
124129
request_id=parent.meta.request_id,
125130
initiator_id=parent.meta.initiator_id,
126131
account_id=parent.meta.account_id,
127132
level=parent.meta.level + 1,
128-
context=deepcopy(parent.meta.context),
133+
context=context,
129134
ancestors=[ParentEvent(id=parent.get_id(), name=parent.get_name())] + parent.meta.ancestors,
130135
)
131136

backend/infrahub/events/node_action.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from pydantic import Field, computed_field
44

5-
from infrahub.core.branch import Branch
65
from infrahub.core.changelog.models import NodeChangelog
76
from infrahub.core.constants import MutationAction
87
from infrahub.message_bus import InfrahubMessage
@@ -95,10 +94,6 @@ def get_messages(self) -> list[InfrahubMessage]:
9594
# )
9695
]
9796

98-
def set_context_branch(self, branch: Branch) -> None:
99-
self.meta.context.branch.id = str(branch.get_uuid())
100-
self.meta.context.branch.name = branch.name
101-
10297

10398
class NodeCreatedEvent(NodeMutatedEvent):
10499
action: MutationAction = MutationAction.CREATED

backend/infrahub/graphql/types/event.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class EventTypeFilter(InputObjectType):
6464
branch_merged = Field(
6565
BranchEventTypeFilter, required=False, description="Filters specific to infrahub.branch.merged events"
6666
)
67+
branch_rebased = Field(
68+
BranchEventTypeFilter, required=False, description="Filters specific to infrahub.branch.rebased events"
69+
)
6770

6871

6972
# ---------------------------------------

backend/infrahub/task_manager/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,13 @@ def add_event_type_filter(
135135
if branches:
136136
self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.branch.name": branches}))
137137

138+
if branch_rebased := event_type_filter.get("branch_rebased"):
139+
branches = branch_rebased.get("branches") or []
140+
if "infrahub.branch.created" not in event_type:
141+
event_type.append("infrahub.branch.rebased")
142+
if branches:
143+
self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.branch.name": branches}))
144+
138145
if event_type:
139146
self.event = EventNameFilter(name=event_type)
140147

0 commit comments

Comments
 (0)