11from __future__ import annotations
22
33from typing import Any
4+ from uuid import uuid4
45
56import pydantic
67from prefect import flow , get_run_logger
1112from infrahub .context import InfrahubContext # noqa: TC001 needed for prefect flow
1213from infrahub .core import registry
1314from infrahub .core .branch import Branch
14- from infrahub .core .changelog .diff import DiffChangelogCollector
15+ from infrahub .core .changelog .diff import DiffChangelogCollector , MigrationTracker
1516from infrahub .core .constants import MutationAction
1617from infrahub .core .diff .coordinator import DiffCoordinator
1718from infrahub .core .diff .ipam_diff_parser import IpamDiffParser
1819from infrahub .core .diff .merger .merger import DiffMerger
19- from infrahub .core .diff .model .path import BranchTrackingId
20+ from infrahub .core .diff .model .path import BranchTrackingId , EnrichedDiffRoot , EnrichedDiffRootMetadata
2021from infrahub .core .diff .repository .repository import DiffRepository
2122from infrahub .core .merge import BranchMerger
2223from infrahub .core .migrations .schema .models import SchemaApplyMigrationData
2324from infrahub .core .migrations .schema .tasks import schema_apply_migrations
25+ from infrahub .core .timestamp import Timestamp
2426from infrahub .core .validators .determiner import ConstraintValidatorDeterminer
2527from infrahub .core .validators .models .validate_migration import SchemaValidateMigrationData
2628from infrahub .core .validators .tasks import schema_validate_migrations
@@ -54,6 +56,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
5456 diff_repository = await component_registry .get_component (DiffRepository , db = db , branch = obj )
5557 diff_coordinator = await component_registry .get_component (DiffCoordinator , db = db , branch = obj )
5658 diff_merger = await component_registry .get_component (DiffMerger , db = db , branch = obj )
59+ initial_from_time = Timestamp (obj .get_branched_from ())
5760 merger = BranchMerger (
5861 db = db ,
5962 diff_coordinator = diff_coordinator ,
@@ -62,7 +65,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
6265 source_branch = obj ,
6366 service = service ,
6467 )
65- diff_repository = await component_registry . get_component ( DiffRepository , db = db , branch = obj )
68+
6669 enriched_diff_metadata = await diff_coordinator .update_branch_diff (base_branch = base_branch , diff_branch = obj )
6770 async for _ in diff_repository .get_all_conflicts_for_diff (
6871 diff_branch_name = enriched_diff_metadata .diff_branch_name , diff_id = enriched_diff_metadata .uuid
@@ -97,7 +100,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
97100 raise ValidationError (",\n " .join (error_messages ))
98101
99102 schema_in_main_before = merger .destination_schema .duplicate ()
100-
103+ migrations = []
101104 async with lock .registry .global_graph_lock ():
102105 async with db .start_transaction () as dbt :
103106 await obj .rebase (db = dbt )
@@ -134,6 +137,14 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
134137 for error in errors :
135138 log .error (error )
136139
140+ default_branch_diff = await _get_diff_root (
141+ diff_coordinator = diff_coordinator ,
142+ enriched_diff_metadata = enriched_diff_metadata ,
143+ diff_repository = diff_repository ,
144+ base_branch = base_branch ,
145+ target_from = initial_from_time ,
146+ )
147+
137148 # -------------------------------------------------------------
138149 # Trigger the reconciliation of IPAM data after the rebase
139150 # -------------------------------------------------------------
@@ -156,14 +167,26 @@ async def rebase_branch(branch: str, context: InfrahubContext, service: Infrahub
156167 # -------------------------------------------------------------
157168 # Generate an event to indicate that a branch has been rebased
158169 # -------------------------------------------------------------
159- # TODO Add account information
160- await service .event .send (
161- event = BranchRebasedEvent (
162- branch_name = obj .name ,
163- branch_id = str (obj .uuid ),
164- meta = EventMeta .from_context (context = context , branch = registry .get_global_branch ()),
165- )
170+ rebase_event = BranchRebasedEvent (
171+ branch_name = obj .name , branch_id = str (obj .uuid ), meta = EventMeta (branch = obj , context = context )
172+ )
173+ events : list [InfrahubEvent ] = [rebase_event ]
174+ changelog_collector = DiffChangelogCollector (
175+ diff = default_branch_diff , branch = obj , db = db , migration_tracker = MigrationTracker (migrations = migrations )
166176 )
177+ for action , node_changelog in changelog_collector .collect_changelogs ():
178+ mutate_event = NodeMutatedEvent (
179+ kind = node_changelog .node_kind ,
180+ node_id = node_changelog .node_id ,
181+ data = node_changelog ,
182+ action = MutationAction .from_diff_action (diff_action = action ),
183+ fields = node_changelog .updated_fields ,
184+ meta = EventMeta .from_parent (parent = rebase_event , branch = obj ),
185+ )
186+ events .append (mutate_event )
187+
188+ for event in events :
189+ await service .event .send (event )
167190
168191
169192@flow (name = "branch-merge" , flow_run_name = "Merge branch {branch} into main" )
@@ -258,7 +281,7 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
258281 events : list [InfrahubEvent ] = [merge_event ]
259282
260283 for action , node_changelog in node_events :
261- meta = EventMeta .from_parent (parent = merge_event )
284+ meta = EventMeta .from_parent (parent = merge_event , branch = default_branch )
262285 mutate_event = NodeMutatedEvent (
263286 kind = node_changelog .node_kind ,
264287 node_id = node_changelog .node_id ,
@@ -267,7 +290,6 @@ async def merge_branch(branch: str, context: InfrahubContext, service: InfrahubS
267290 fields = node_changelog .updated_fields ,
268291 meta = meta ,
269292 )
270- mutate_event .set_context_branch (branch = default_branch )
271293 events .append (mutate_event )
272294
273295 for event in events :
@@ -364,3 +386,26 @@ async def create_branch(model: BranchCreateModel, context: InfrahubContext, serv
364386 context = context ,
365387 parameters = {"branch" : obj .name , "branch_id" : str (obj .uuid )},
366388 )
389+
390+
391+ async def _get_diff_root (
392+ diff_coordinator : DiffCoordinator ,
393+ enriched_diff_metadata : EnrichedDiffRootMetadata ,
394+ diff_repository : DiffRepository ,
395+ base_branch : Branch ,
396+ target_from : Timestamp ,
397+ ) -> EnrichedDiffRoot :
398+ default_branch_diff = await diff_coordinator .create_or_update_arbitrary_timeframe_diff (
399+ base_branch = base_branch ,
400+ diff_branch = base_branch ,
401+ from_time = target_from ,
402+ to_time = enriched_diff_metadata .to_time ,
403+ name = str (uuid4 ()),
404+ )
405+ # make sure we have the actual diff with data and not just the metadata
406+ if not isinstance (default_branch_diff , EnrichedDiffRoot ):
407+ default_branch_diff = await diff_repository .get_one (
408+ diff_branch_name = base_branch .name , diff_id = default_branch_diff .uuid
409+ )
410+
411+ return default_branch_diff
0 commit comments