Skip to content

Commit 666f97c

Browse files
committed
Send node mutation events after branch merge
1 parent a36bfd7 commit 666f97c

File tree

12 files changed

+244
-23
lines changed

12 files changed

+244
-23
lines changed

backend/infrahub/context.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,18 @@
55
from infrahub.core.branch import Branch
66

77

8-
class EventContext(BaseModel):
8+
class ParentEvent(BaseModel):
9+
id: str
910
name: str
1011

1112

13+
class EventContext(BaseModel):
14+
name: str = Field(..., description="The name of the event")
15+
id: str = Field(..., description="The ID of the event")
16+
parent_id: str | None = Field(default=None)
17+
ancestors: list[ParentEvent] = Field(default_factory=list)
18+
19+
1220
class BranchContext(BaseModel):
1321
name: str
1422
id: str | None = None
@@ -17,8 +25,15 @@ class BranchContext(BaseModel):
1725
class InfrahubContext(BaseModel):
1826
branch: BranchContext
1927
account: AccountSession
20-
events: list[EventContext] = Field(default_factory=list)
28+
event: EventContext | None = Field(default=None)
2129

2230
@classmethod
2331
def init(cls, branch: Branch, account: AccountSession) -> Self:
2432
return cls(branch=BranchContext(name=branch.name, id=str(branch.uuid)), account=account)
33+
34+
def set_event(self, name: str, id: str) -> None:
35+
if self.event:
36+
self.event.name = name
37+
self.event.id = id
38+
else:
39+
self.event = EventContext(name=name, id=id)

backend/infrahub/core/branch/tasks.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
1212
from infrahub.core import registry
1313
from infrahub.core.branch import Branch
14+
from infrahub.core.changelog.diff import DiffChangelogCollector
15+
from infrahub.core.constants import MutationAction
1416
from infrahub.core.diff.coordinator import DiffCoordinator
1517
from infrahub.core.diff.ipam_diff_parser import IpamDiffParser
1618
from infrahub.core.diff.merger.merger import DiffMerger
@@ -23,8 +25,9 @@
2325
from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData
2426
from infrahub.core.validators.tasks import schema_validate_migrations
2527
from infrahub.dependencies.registry import get_component_registry
26-
from infrahub.events.branch_action import BranchCreatedEvent, BranchDeletedEvent, BranchRebasedEvent
27-
from infrahub.events.models import EventMeta
28+
from infrahub.events.branch_action import BranchCreatedEvent, BranchDeletedEvent, BranchMergedEvent, BranchRebasedEvent
29+
from infrahub.events.models import EventMeta, InfrahubEvent
30+
from infrahub.events.node_action import NodeMutatedEvent
2831
from infrahub.exceptions import BranchNotFoundError, MergeFailedError, ValidationError
2932
from infrahub.graphql.mutations.models import BranchCreateModel # noqa: TC001
3033
from infrahub.log import get_log_data
@@ -169,7 +172,9 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
169172
await add_tags(branches=[branch, registry.default_branch])
170173

171174
obj = await Branch.get_by_name(db=db, name=branch)
175+
default_branch = await registry.get_branch(db=db, branch=registry.default_branch)
172176
component_registry = get_component_registry()
177+
merge_event = BranchMergedEvent(meta=EventMeta.from_context(context=context, branch=obj))
173178

174179
merger: BranchMerger | None = None
175180
async with lock.registry.global_graph_lock():
@@ -187,13 +192,15 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
187192
service=service,
188193
)
189194
try:
190-
await merger.merge()
195+
branch_diff = await merger.merge()
191196
except Exception as exc:
192197
log.exception("Merge failed, beginning rollback")
193198
await merger.rollback()
194199
raise MergeFailedError(branch_name=branch) from exc
195200
await merger.update_schema()
196201

202+
changelog_collector = DiffChangelogCollector(diff=branch_diff, branch=obj, db=db)
203+
node_events = changelog_collector.collect_changelogs()
197204
if merger and merger.migrations:
198205
errors = await schema_apply_migrations(
199206
message=SchemaApplyMigrationData(
@@ -242,6 +249,24 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
242249
)
243250
await service.message_bus.send(message=message)
244251

252+
events: list[InfrahubEvent] = [merge_event]
253+
254+
for action, node_changelog in node_events:
255+
meta = EventMeta.from_parent(parent=merge_event)
256+
mutate_event = NodeMutatedEvent(
257+
kind=node_changelog.node_kind,
258+
node_id=node_changelog.node_id,
259+
data=node_changelog,
260+
action=MutationAction.from_diff_action(diff_action=action),
261+
fields=node_changelog.updated_fields,
262+
meta=meta,
263+
)
264+
mutate_event.set_context_branch(branch=default_branch)
265+
events.append(mutate_event)
266+
267+
for event in events:
268+
await service.event.send(event=event)
269+
245270

246271
@flow(name="branch-delete", flow_run_name="Delete branch {branch}")
247272
async def delete_branch(branch: str, context: InfrahubContext, service: InfrahubServices) -> None:

backend/infrahub/core/changelog/diff.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,15 @@ def _populate_diff_nodes(self) -> None:
5050
def get_node(self, node_id: str) -> NodeInDiff:
5151
return self._diff_nodes[node_id]
5252

53+
def get_peer_kind(self, peer_id: str, node_kind: str, relationship_name: str) -> str:
54+
"""If the peer kind doesn't exist in the diff use the peer kind from the schema"""
55+
try:
56+
return self.get_node(node_id=peer_id).kind
57+
except KeyError:
58+
schema = self._db.schema.get(node_kind, branch=self._branch, duplicate=False)
59+
rel_schema = schema.get_relationship(name=relationship_name)
60+
return rel_schema.peer
61+
5362
def _process_node(self, node: EnrichedDiffNode) -> NodeChangelog:
5463
node_changelog = NodeChangelog(node_id=node.uuid, node_kind=node.kind, display_label=node.label)
5564
schema = self._db.schema.get(node_changelog.node_kind, branch=self._branch, duplicate=False)
@@ -64,8 +73,15 @@ def _process_node(self, node: EnrichedDiffNode) -> NodeChangelog:
6473
def _process_node_attribute(
6574
self, node: NodeChangelog, attribute: EnrichedDiffAttribute, schema: MainSchemaTypes
6675
) -> None:
67-
schema_attribute = schema.get_attribute(name=attribute.name)
68-
changelog_attribute = AttributeChangelog(name=attribute.name, kind=schema_attribute.kind)
76+
try:
77+
schema_attribute = schema.get_attribute(name=attribute.name)
78+
attribute_kind = schema_attribute.kind
79+
except ValueError:
80+
# This would currently happen if there has been a schema migration as part of the merge
81+
# then we don't have access to the attribute kind
82+
attribute_kind = "n/a"
83+
84+
changelog_attribute = AttributeChangelog(name=attribute.name, kind=attribute_kind)
6985
for attr_property in attribute.properties:
7086
match attr_property.property_type:
7187
case DatabaseEdgeType.HAS_VALUE:
@@ -117,10 +133,18 @@ def _process_node_cardinality_one_relationship(
117133
case DatabaseEdgeType.IS_RELATED:
118134
if rel_prop.new_value:
119135
changelog_rel.peer_id = rel_prop.new_value
120-
changelog_rel.peer_kind = self.get_node(node_id=rel_prop.new_value).kind
136+
changelog_rel.peer_kind = self.get_peer_kind(
137+
peer_id=rel_prop.new_value,
138+
node_kind=node.node_kind,
139+
relationship_name=relationship.name,
140+
)
121141
if rel_prop.previous_value:
122142
changelog_rel.peer_id_previous = rel_prop.previous_value
123-
changelog_rel.peer_kind_previous = self.get_node(node_id=rel_prop.previous_value).kind
143+
changelog_rel.peer_kind_previous = self.get_peer_kind(
144+
peer_id=rel_prop.previous_value,
145+
node_kind=node.node_kind,
146+
relationship_name=relationship.name,
147+
)
124148
case DatabaseEdgeType.IS_PROTECTED:
125149
changelog_rel.add_property(
126150
name="is_protected",
@@ -161,7 +185,11 @@ def _process_node_cardinality_many_relationship(
161185
changelog_rel = RelationshipCardinalityManyChangelog(name=relationship.name)
162186
for peer in relationship.relationships:
163187
peer_log = RelationshipPeerChangelog(
164-
peer_id=peer.peer_id, peer_kind=self.get_node(node_id=peer.peer_id).kind, peer_status=peer.action
188+
peer_id=peer.peer_id,
189+
peer_kind=self.get_peer_kind(
190+
peer_id=peer.peer_id, node_kind=node.node_kind, relationship_name=relationship.name
191+
),
192+
peer_status=peer.action,
165193
)
166194
for peer_prop in peer.properties:
167195
match peer_prop.property_type:

backend/infrahub/core/constants/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,18 @@ class MutationAction(InfrahubStringEnum):
180180
UPDATED = "updated"
181181
UNDEFINED = "undefined"
182182

183+
@classmethod
184+
def from_diff_action(cls, diff_action: DiffAction) -> MutationAction:
185+
match diff_action:
186+
case DiffAction.ADDED:
187+
return MutationAction.CREATED
188+
case DiffAction.REMOVED:
189+
return MutationAction.DELETED
190+
case DiffAction.UPDATED:
191+
return MutationAction.UPDATED
192+
case DiffAction.UNCHANGED:
193+
return MutationAction.UNDEFINED
194+
183195

184196
class PathResourceType(InfrahubStringEnum):
185197
SCHEMA = "schema"

backend/infrahub/events/branch_action.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ def event_name(self) -> str:
7070
return f"{EVENT_NAMESPACE}.branch.created"
7171

7272

73+
class BranchMergedEvent(InfrahubEvent):
74+
"""Event generated when a branch has been merged"""
75+
76+
def get_resource(self) -> dict[str, str]:
77+
return {
78+
"prefect.resource.id": f"infrahub.branch.{self.meta.get_branch_id()}",
79+
"infrahub.node.kind": "Branch",
80+
"infrahub.node.id": self.meta.get_branch_id(),
81+
"infrahub.node.label": self.meta.context.branch.name,
82+
}
83+
84+
def get_messages(self) -> list[InfrahubMessage]:
85+
return []
86+
87+
@computed_field
88+
def event_name(self) -> str:
89+
return f"{EVENT_NAMESPACE}.branch.merged"
90+
91+
7392
class BranchRebasedEvent(InfrahubEvent):
7493
"""Event generated when a branch has been rebased"""
7594

backend/infrahub/events/models.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from __future__ import annotations
22

3-
from typing import Any, cast, final
3+
from copy import deepcopy
4+
from typing import Any, Self, cast, final
45
from uuid import UUID, uuid4
56

6-
from pydantic import BaseModel, Field, computed_field
7+
from pydantic import BaseModel, Field, PrivateAttr, computed_field, model_validator
78

89
from infrahub import __version__
910
from infrahub.auth import AccountSession, AuthType
@@ -45,6 +46,7 @@ class EventMeta(BaseModel):
4546

4647
parent: UUID | None = Field(default=None, description="The UUID of the parent event if applicable")
4748
ancestors: list[ParentEvent] = Field(default_factory=list, description="Any event used to trigger this event")
49+
_created_with_context: bool = PrivateAttr(default=False)
4850

4951
def get_branch_id(self) -> str:
5052
if self.context.branch.id:
@@ -123,10 +125,20 @@ def from_parent(cls, parent: InfrahubEvent) -> EventMeta:
123125
initiator_id=parent.meta.initiator_id,
124126
account_id=parent.meta.account_id,
125127
level=parent.meta.level + 1,
126-
context=parent.meta.context,
128+
context=deepcopy(parent.meta.context),
127129
ancestors=[ParentEvent(id=parent.get_id(), name=parent.get_name())] + parent.meta.ancestors,
128130
)
129131

132+
@classmethod
133+
def from_context(cls, context: InfrahubContext, branch: Branch | None = None) -> EventMeta:
134+
# Create a copy of the context so local changes aren't brought back to a parent object
135+
meta = cls(context=deepcopy(context))
136+
meta._created_with_context = True
137+
if branch:
138+
meta.context.branch.name = branch.name
139+
meta.context.branch.id = str(branch.get_uuid())
140+
return meta
141+
130142

131143
class InfrahubEvent(BaseModel):
132144
meta: EventMeta = Field(..., description="Metadata for the event")
@@ -180,3 +192,10 @@ def get_message_meta(self) -> Meta:
180192
@computed_field
181193
def event_name(self) -> str:
182194
raise NotImplementedError("The event name has not been defined")
195+
196+
@model_validator(mode="after")
197+
def update_context(self) -> Self:
198+
"""Update the context object using this event provided that the meta data was created with a context."""
199+
if self.meta._created_with_context:
200+
self.meta.context.set_event(self.get_name(), id=self.get_id())
201+
return self

backend/infrahub/events/node_action.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from pydantic import Field, computed_field
44

5+
from infrahub.core.branch import Branch
56
from infrahub.core.changelog.models import NodeChangelog
67
from infrahub.core.constants import MutationAction
78
from infrahub.message_bus import InfrahubMessage
@@ -93,6 +94,10 @@ def get_messages(self) -> list[InfrahubMessage]:
9394
# )
9495
]
9596

97+
def set_context_branch(self, branch: Branch) -> None:
98+
self.meta.context.branch.id = str(branch.get_uuid())
99+
self.meta.context.branch.name = branch.name
100+
96101

97102
class NodeCreatedEvent(NodeMutatedEvent):
98103
action: MutationAction = MutationAction.CREATED

backend/infrahub/services/adapters/event/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ async def _send_bus(self, event: InfrahubEvent) -> None:
3030

3131
async def _send_prefect(self, event: InfrahubEvent) -> None:
3232
emit_event(
33+
id=event.meta.id,
3334
event=event.get_name(),
3435
resource=event.get_resource(),
3536
related=event.get_related(),

backend/tests/helpers/utils.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import time
2+
from pathlib import Path
3+
4+
import httpx
15
import pytest
26
from testcontainers.core.container import DockerContainer
37
from testcontainers.core.waiting_utils import wait_for_logs
@@ -38,17 +42,38 @@ def start_prefect_server_container(
3842
if not INFRAHUB_USE_TEST_CONTAINERS:
3943
return None
4044

45+
prefect_base = Path(Path(__file__).parent.resolve() / "./../../infrahub/prefect_server")
4146
container = (
4247
DockerContainer(image="prefecthq/prefect:3.1.14-python3.12")
43-
.with_command("prefect server start --host 0.0.0.0 --ui")
48+
.with_command("uvicorn --host 0.0.0.0 --port 4200 --factory prefect_server.app:create_infrahub_prefect")
4449
.with_exposed_ports(PORT_PREFECT)
50+
.with_volume_mapping(host=str(prefect_base), container="/opt/prefect/prefect_server", mode="ro")
51+
.with_env(key="PREFECT_SERVER_SERVICES_EVENT_PERSISTER_FLUSH_INTERVAL", value="1")
4552
)
4653

4754
def cleanup() -> None:
4855
container.stop()
4956

5057
container.start()
51-
wait_for_logs(container, "Configure Prefect to communicate with the server")
58+
59+
mapped_port = get_exposed_port(container, PORT_PREFECT)
60+
# As our entrypoint doesn't print out any logs on startup we can't "wait for logs"
61+
wait_for_prefect(port=mapped_port)
5262
request.addfinalizer(cleanup)
5363

54-
return {PORT_PREFECT: get_exposed_port(container, PORT_PREFECT)}
64+
return {PORT_PREFECT: mapped_port}
65+
66+
67+
url = "http://localhost:52879/api/admin/version" # Replace with your target URL
68+
69+
70+
def wait_for_prefect(port: int) -> None:
71+
for _ in range(120):
72+
try:
73+
response = httpx.get(f"http://localhost:{port}/api/admin/version", timeout=1)
74+
if response.status_code == 200:
75+
return
76+
except httpx.HTTPError:
77+
time.sleep(1)
78+
79+
pytest.fail(reason="Prefect didn't start in an orderly fashion")

0 commit comments

Comments
 (0)