Skip to content

Commit 1597160

Browse files
committed
Add new flow to run migrations on a branch
This is done by instantiating a `MigrationRunner` which will take care of determining what migrations need to run to then run them if any.
1 parent e290fcd commit 1597160

File tree

8 files changed

+168
-136
lines changed

8 files changed

+168
-136
lines changed

backend/infrahub/cli/db.py

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from rich.table import Table
1919

2020
from infrahub import config
21+
from infrahub.auth import AccountSession, AuthType
22+
from infrahub.context import InfrahubContext
2123
from infrahub.core import registry
2224
from infrahub.core.branch import Branch
23-
from infrahub.core.branch.enums import BranchStatus
25+
from infrahub.core.branch.tasks import rebase_branch
2426
from infrahub.core.constants import GLOBAL_BRANCH_NAME
2527
from infrahub.core.graph import GRAPH_VERSION
2628
from infrahub.core.graph.constraints import ConstraintManagerBase, ConstraintManagerMemgraph, ConstraintManagerNeo4j
@@ -40,7 +42,6 @@
4042
from infrahub.core.migrations.graph import get_graph_migrations, get_migration_by_number
4143
from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
4244
from infrahub.core.migrations.schema.tasks import schema_apply_migrations
43-
from infrahub.core.migrations.shared import MigrationWithRebase
4445
from infrahub.core.schema import SchemaRoot, core_models, internal_schema
4546
from infrahub.core.schema.definitions.deprecated import deprecated_models
4647
from infrahub.core.schema.manager import SchemaManager
@@ -49,6 +50,7 @@
4950
from infrahub.database import DatabaseType
5051
from infrahub.database.memgraph import IndexManagerMemgraph
5152
from infrahub.database.neo4j import IndexManagerNeo4j
53+
from infrahub.exceptions import ValidationError
5254

5355
from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
5456
from .db_commands.check_inheritance import check_inheritance
@@ -63,7 +65,12 @@ def get_timestamp_string() -> str:
6365

6466
if TYPE_CHECKING:
6567
from infrahub.cli.context import CliContext
66-
from infrahub.core.migrations.shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration
68+
from infrahub.core.migrations.shared import (
69+
ArbitraryMigration,
70+
GraphMigration,
71+
InternalSchemaMigration,
72+
MigrationWithRebase,
73+
)
6774
from infrahub.database import InfrahubDatabase
6875
from infrahub.database.index import IndexManagerBase
6976

@@ -354,51 +361,30 @@ async def migrate_database(
354361
return True
355362

356363

357-
async def rebase_and_migrate_branches(db: InfrahubDatabase, current_graph_version: int) -> bool:
358-
"""Only applies migrations that aim at rebasing branches."""
359-
branches = [
360-
b
361-
for b in await Branch.get_list(db=db)
362-
if b.name not in [registry.default_branch, GLOBAL_BRANCH_NAME]
363-
and (not b.graph_version or b.graph_version < current_graph_version)
364-
]
365-
364+
async def trigger_rebase_branches(db: InfrahubDatabase) -> None:
365+
"""Trigger rebase of non-default branches, also triggering migrations in the process."""
366+
branches = [b for b in await Branch.get_list(db=db) if b.name not in [registry.default_branch, GLOBAL_BRANCH_NAME]]
366367
if not branches:
367-
return True
368+
return
368369

369370
rprint(f"Planning rebase and migrations for {len(branches)} branches: {', '.join([b.name for b in branches])}")
370371

371372
for branch in branches:
372-
migrations = [
373-
m
374-
for m in await detect_migration_to_run(current_graph_version=branch.graph_version or 0)
375-
if isinstance(m, MigrationWithRebase)
376-
]
377-
rprint(
378-
f"Detected {len(migrations)} migrations to run against '{branch.name}' (ID: {branch.uuid}): {', '.join([m.name for m in migrations])}"
379-
)
380-
381-
for migration in migrations:
382-
execution_result = await migration.execute_against_branch(db=db, branch=branch)
383-
validation_result = None
384-
385-
if execution_result.success:
386-
validation_result = await migration.validate_migration(db=db)
387-
if validation_result.success and branch.status != BranchStatus.NEED_UPGRADE_REBASE:
388-
branch.graph_version = migration.minimum_version + 1
389-
await branch.save(db=db)
390-
rprint(f"Migration: {migration.name} {SUCCESS_BADGE}")
391-
392-
if not execution_result.success or (validation_result and not validation_result.success):
393-
rprint(f"Migration: {migration.name} {FAILED_BADGE}")
394-
for error in execution_result.errors:
395-
rprint(f" {error}")
396-
if validation_result and not validation_result.success:
397-
for error in validation_result.errors:
398-
rprint(f" {error}")
399-
return False
400-
401-
return True
373+
if branch.graph_version == GRAPH_VERSION:
374+
continue
375+
376+
rprint(f"Rebasing branch '{branch.name}' (ID: {branch.uuid})...", end="")
377+
try:
378+
await rebase_branch(
379+
branch=branch.name,
380+
context=InfrahubContext.init(
381+
branch=branch, account=AccountSession(auth_type=AuthType.NONE, authenticated=False, account_id="")
382+
),
383+
send_events=False,
384+
)
385+
rprint("done")
386+
except ValidationError:
387+
rprint("failed")
402388

403389

404390
async def initialize_internal_schema() -> None:

backend/infrahub/cli/upgrade.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
detect_migration_to_run,
3636
initialize_internal_schema,
3737
migrate_database,
38-
rebase_and_migrate_branches,
38+
trigger_rebase_branches,
3939
update_core_schema,
4040
)
4141

@@ -112,13 +112,8 @@ async def upgrade_cmd(
112112
# -------------------------------------------
113113
# Perform branch rebase and apply migrations to them
114114
# -------------------------------------------
115-
if rebase_branches and not await rebase_and_migrate_branches(
116-
db=dbdriver, current_graph_version=root_node.graph_version
117-
):
118-
# A migration failed, stop the upgrade process
119-
rprint("Upgrade cancelled due to branch rebase and migration failure.")
120-
await dbdriver.close()
121-
return
115+
if rebase_branches:
116+
await trigger_rebase_branches(db=dbdriver)
122117

123118
await dbdriver.close()
124119

backend/infrahub/core/branch/tasks.py

Lines changed: 68 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow
1313
from infrahub.core import registry
1414
from infrahub.core.branch import Branch
15+
from infrahub.core.branch.enums import BranchStatus
1516
from infrahub.core.changelog.diff import DiffChangelogCollector, MigrationTracker
1617
from infrahub.core.constants import MutationAction
1718
from infrahub.core.diff.coordinator import DiffCoordinator
@@ -21,7 +22,10 @@
2122
from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot, EnrichedDiffRootMetadata
2223
from infrahub.core.diff.models import RequestDiffUpdate
2324
from infrahub.core.diff.repository.repository import DiffRepository
25+
from infrahub.core.graph import GRAPH_VERSION
2426
from infrahub.core.merge import BranchMerger
27+
from infrahub.core.migrations.exceptions import MigrationFailureError
28+
from infrahub.core.migrations.runner import MigrationRunner
2529
from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
2630
from infrahub.core.migrations.schema.tasks import schema_apply_migrations
2731
from infrahub.core.timestamp import Timestamp
@@ -39,6 +43,7 @@
3943
from infrahub.workflows.catalogue import (
4044
BRANCH_CANCEL_PROPOSED_CHANGES,
4145
BRANCH_MERGE_POST_PROCESS,
46+
BRANCH_MIGRATE,
4247
DIFF_REFRESH_ALL,
4348
DIFF_UPDATE,
4449
GIT_REPOSITORIES_CREATE_BRANCH,
@@ -51,6 +56,7 @@
5156

5257
@flow(name="branch-rebase", flow_run_name="Rebase branch {branch}")
5358
async def rebase_branch(branch: str, context: InfrahubContext, send_events: bool = True) -> None: # noqa: PLR0915
59+
workflow = get_workflow()
5460
database = await get_database()
5561
async with database.start_session() as db:
5662
log = get_run_logger()
@@ -69,7 +75,7 @@ async def rebase_branch(branch: str, context: InfrahubContext, send_events: bool
6975
diff_repository=diff_repository,
7076
source_branch=obj,
7177
diff_locker=DiffLocker(),
72-
workflow=get_workflow(),
78+
workflow=workflow,
7379
)
7480

7581
enriched_diff_metadata = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj)
@@ -156,41 +162,75 @@ async def rebase_branch(branch: str, context: InfrahubContext, send_events: bool
156162
target_branch_name=registry.default_branch,
157163
)
158164
if ipam_node_details:
159-
await get_workflow().submit_workflow(
165+
await workflow.submit_workflow(
160166
workflow=IPAM_RECONCILIATION,
161167
context=context,
162168
parameters={"branch": obj.name, "ipam_node_details": ipam_node_details},
163169
)
164170

165-
await get_workflow().submit_workflow(
166-
workflow=DIFF_REFRESH_ALL, context=context, parameters={"branch_name": obj.name}
167-
)
171+
await workflow.submit_workflow(workflow=DIFF_REFRESH_ALL, context=context, parameters={"branch_name": obj.name})
172+
await workflow.submit_workflow(workflow=BRANCH_MIGRATE, context=context, parameters={"branch_name": obj.name})
168173

169-
if send_events:
170-
# -------------------------------------------------------------
171-
# Generate an event to indicate that a branch has been rebased
172-
# -------------------------------------------------------------
173-
rebase_event = BranchRebasedEvent(
174-
branch_name=obj.name, branch_id=str(obj.uuid), meta=EventMeta(branch=obj, context=context)
175-
)
176-
events: list[InfrahubEvent] = [rebase_event]
177-
changelog_collector = DiffChangelogCollector(
178-
diff=default_branch_diff, branch=obj, db=db, migration_tracker=MigrationTracker(migrations=migrations)
174+
if not send_events:
175+
return
176+
177+
# -------------------------------------------------------------
178+
# Generate an event to indicate that a branch has been rebased
179+
# -------------------------------------------------------------
180+
rebase_event = BranchRebasedEvent(
181+
branch_name=obj.name, branch_id=str(obj.uuid), meta=EventMeta(branch=obj, context=context)
182+
)
183+
events: list[InfrahubEvent] = [rebase_event]
184+
changelog_collector = DiffChangelogCollector(
185+
diff=default_branch_diff, branch=obj, db=db, migration_tracker=MigrationTracker(migrations=migrations)
186+
)
187+
for action, node_changelog in changelog_collector.collect_changelogs():
188+
node_event_class = get_node_event(MutationAction.from_diff_action(diff_action=action))
189+
mutate_event = node_event_class(
190+
kind=node_changelog.node_kind,
191+
node_id=node_changelog.node_id,
192+
changelog=node_changelog,
193+
fields=node_changelog.updated_fields,
194+
meta=EventMeta.from_parent(parent=rebase_event, branch=obj),
179195
)
180-
for action, node_changelog in changelog_collector.collect_changelogs():
181-
node_event_class = get_node_event(MutationAction.from_diff_action(diff_action=action))
182-
mutate_event = node_event_class(
183-
kind=node_changelog.node_kind,
184-
node_id=node_changelog.node_id,
185-
changelog=node_changelog,
186-
fields=node_changelog.updated_fields,
187-
meta=EventMeta.from_parent(parent=rebase_event, branch=obj),
188-
)
189-
events.append(mutate_event)
196+
events.append(mutate_event)
190197

191-
event_service = await get_event_service()
192-
for event in events:
193-
await event_service.send(event)
198+
event_service = await get_event_service()
199+
for event in events:
200+
await event_service.send(event)
201+
202+
203+
@flow(name="migrate_branch", flow_run_name="Apply migrations to branch {branch}")
204+
async def migrate_branch(branch_name: str, context: InfrahubContext) -> None: # noqa: ARG001
205+
db = await get_database()
206+
log = get_run_logger()
207+
208+
branch = await registry.get_branch(db=db, branch=branch_name)
209+
210+
if branch.graph_version == GRAPH_VERSION:
211+
log.info(f"Branch '{branch.name}' is up-to-date")
212+
return
213+
214+
migration_runner = MigrationRunner(branch=branch)
215+
if not migration_runner.has_migrations():
216+
log.info(f"No migrations detected for branch '{branch.name}'")
217+
return
218+
219+
# Branch status will remain as so if the migration process fails
220+
# This will help user to know that a branch is in an invalid state to be used properly and that actions need to be taken
221+
branch.status = BranchStatus.NEED_UPGRADE_REBASE
222+
await branch.save(db=db)
223+
224+
try:
225+
await migration_runner.run(db=db)
226+
except MigrationFailureError as exc:
227+
log.error(f"Failed to migrate branch '{branch.name}': {exc.errors}")
228+
return
229+
230+
if branch.status == BranchStatus.NEED_UPGRADE_REBASE:
231+
branch.status = BranchStatus.OPEN
232+
branch.graph_version = GRAPH_VERSION
233+
await branch.save(db=db)
194234

195235

196236
@flow(name="branch-merge", flow_run_name="Merge branch {branch} into main")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
class MigrationFailureError(Exception):
2+
def __init__(self, errors: list[str]) -> None:
3+
super().__init__()
4+
self.errors = errors

backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,7 @@ async def _do_one_schema_batch(
447447

448448
print("done")
449449

450-
async def process_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
451-
root_node = await get_root_node(db=db, initialize=False)
452-
default_branch = root_node.default_branch
450+
async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
453451
schema_manager = SchemaManager()
454452
internal_schema_root = SchemaRoot(**internal_schema)
455453
schema_manager.register_schema(schema=internal_schema_root)
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
from __future__ import annotations
2+
3+
from typing import TYPE_CHECKING, Sequence
4+
5+
from infrahub.core.migrations.graph import MIGRATIONS
6+
7+
from .exceptions import MigrationFailureError
8+
from .shared import MigrationWithRebase
9+
10+
if TYPE_CHECKING:
11+
from infrahub.core.branch import Branch
12+
from infrahub.database import InfrahubDatabase
13+
14+
15+
class MigrationRunner:
16+
def __init__(self, branch: Branch) -> None:
17+
self.branch = branch
18+
self.applicable_migrations = self._get_applicable_migrations()
19+
20+
def _get_applicable_migrations(self) -> Sequence[MigrationWithRebase]:
21+
applicable_migrations = []
22+
for migration_class in [m for m in MIGRATIONS if issubclass(m, MigrationWithRebase)]:
23+
migration = migration_class.init()
24+
if self.branch.graph_version and self.branch.graph_version > migration.minimum_version:
25+
continue
26+
applicable_migrations.append(migration)
27+
28+
return applicable_migrations
29+
30+
def has_migrations(self) -> bool:
31+
return bool(self.applicable_migrations)
32+
33+
async def run(self, db: InfrahubDatabase) -> None:
34+
if not self.has_migrations():
35+
return
36+
37+
for migration in self.applicable_migrations:
38+
execution_result = await migration.execute_against_branch(db=db, branch=self.branch)
39+
validation_result = None
40+
41+
if execution_result.success:
42+
validation_result = await migration.validate_migration(db=db)
43+
# FIXME: graph_version should be changed elsewhere?
44+
# if validation_result.success and self.branch.status != BranchStatus.NEED_UPGRADE_REBASE:
45+
# self.branch.graph_version = migration.minimum_version + 1
46+
# await self.branch.save(db=db)
47+
48+
if not execution_result.success or (validation_result and not validation_result.success):
49+
if execution_result.errors:
50+
raise MigrationFailureError(errors=execution_result.errors)
51+
52+
if validation_result and not validation_result.success:
53+
raise MigrationFailureError(errors=validation_result.errors)

0 commit comments

Comments
 (0)