Skip to content

Commit adde2bf

Browse files
committed
Allow migrations to run against individual branches
1 parent 45e6ed5 commit adde2bf

File tree

5 files changed

+72
-53
lines changed

5 files changed

+72
-53
lines changed

backend/infrahub/cli/db.py

Lines changed: 49 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from infrahub import config
2323
from infrahub.core import registry
2424
from infrahub.core.branch import Branch
25+
from infrahub.core.branch.enums import BranchStatus
2526
from infrahub.core.constants import GLOBAL_BRANCH_NAME
2627
from infrahub.core.graph import GRAPH_VERSION
2728
from infrahub.core.graph.constraints import ConstraintManagerBase, ConstraintManagerMemgraph, ConstraintManagerNeo4j
@@ -175,7 +176,10 @@ async def migrate_cmd(
175176
context: CliContext = ctx.obj
176177
dbdriver = await context.init_db(retry=1)
177178

178-
migrations = await detect_migration_to_run(db=dbdriver, migration_number=migration_number)
179+
root_node = await get_root_node(db=dbdriver)
180+
migrations = await detect_migration_to_run(
181+
current_graph_version=root_node.graph_version, migration_number=migration_number
182+
)
179183

180184
if check or not migrations:
181185
return
@@ -344,17 +348,16 @@ async def index(
344348

345349

346350
async def detect_migration_to_run(
347-
db: InfrahubDatabase, migration_number: int | str | None = None
351+
current_graph_version: int, migration_number: int | str | None = None
348352
) -> Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase]:
349353
"""Return a sequence of migrations to apply to upgrade the database."""
350354
rprint("Checking current state of the database")
351355
migrations: list[GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase] = []
352356

353-
root_node = await get_root_node(db=db)
354357
if migration_number:
355358
migration = get_migration_by_number(migration_number)
356359
migrations.append(migration)
357-
if root_node.graph_version > migration.minimum_version:
360+
if current_graph_version > migration.minimum_version:
358361
rprint(
359362
f"Migration {migration_number} already applied. To apply again, run the command without the --check flag."
360363
)
@@ -363,13 +366,13 @@ async def detect_migration_to_run(
363366
f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
364367
)
365368
else:
366-
migrations.extend(await get_graph_migrations(root=root_node))
369+
migrations.extend(await get_graph_migrations(current_graph_version=current_graph_version))
367370
if not migrations:
368-
rprint(f"Database up-to-date (v{root_node.graph_version}), no migration to execute.")
371+
rprint(f"Database up-to-date (v{current_graph_version}), no migration to execute.")
369372
return []
370373

371374
rprint(
372-
f"Database needs to be updated (v{root_node.graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
375+
f"Database needs to be updated (v{current_graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
373376
)
374377
return migrations
375378

@@ -387,6 +390,9 @@ async def migrate_database(
387390
db: The database object.
388391
migration_number: If provided, the function will only apply the migration with the given number. Defaults to None.
389392
"""
393+
if not migrations:
394+
return True
395+
390396
if initialize:
391397
await initialize_registry(db=db)
392398

@@ -415,33 +421,49 @@ async def migrate_database(
415421
return True
416422

417423

418-
async def rebase_and_migrate_branches(
419-
db: InfrahubDatabase,
420-
migrations: Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase],
421-
) -> bool:
424+
async def rebase_and_migrate_branches(db: InfrahubDatabase, current_graph_version: int) -> bool:
422425
"""Only applies migrations that aim at rebasing branches."""
423-
branches = [b for b in await Branch.get_list(db=db) if b.name not in [registry.default_branch, GLOBAL_BRANCH_NAME]]
426+
branches = [
427+
b
428+
for b in await Branch.get_list(db=db)
429+
if b.name not in [registry.default_branch, GLOBAL_BRANCH_NAME]
430+
and (not b.graph_version or b.graph_version < current_graph_version)
431+
]
432+
433+
if not branches:
434+
return True
435+
424436
rprint(f"Planning rebase and migrations for {len(branches)} branches: {', '.join([b.name for b in branches])}")
425437

426-
rebase_migrations = [m for m in migrations if isinstance(m, MigrationWithRebase)]
438+
for branch in branches:
439+
migrations = [
440+
m
441+
for m in await detect_migration_to_run(current_graph_version=branch.graph_version or current_graph_version)
442+
if isinstance(m, MigrationWithRebase)
443+
]
444+
rprint(
445+
f"Detected {len(migrations)} migrations to run against '{branch.name}' (ID: {branch.uuid}): {', '.join([m.name for m in migrations])}"
446+
)
427447

428-
for migration in rebase_migrations:
429-
execution_result = await migration.execute_against_branches(db=db, branches=branches)
430-
validation_result = None
448+
for migration in migrations:
449+
execution_result = await migration.execute_against_branch(db=db, branch=branch)
450+
validation_result = None
431451

432-
if execution_result.success:
433-
validation_result = await migration.validate_migration(db=db)
434-
if validation_result.success:
435-
rprint(f"Migration: {migration.name} {SUCCESS_BADGE}")
452+
if execution_result.success:
453+
validation_result = await migration.validate_migration(db=db)
454+
if validation_result.success and branch.status != BranchStatus.NEED_UPGRADE_REBASE:
455+
branch.graph_version = migration.minimum_version + 1
456+
await branch.save(db=db)
457+
rprint(f"Migration: {migration.name} {SUCCESS_BADGE}")
436458

437-
if not execution_result.success or (validation_result and not validation_result.success):
438-
rprint(f"Migration: {migration.name} {FAILED_BADGE}")
439-
for error in execution_result.errors:
440-
rprint(f" {error}")
441-
if validation_result and not validation_result.success:
442-
for error in validation_result.errors:
459+
if not execution_result.success or (validation_result and not validation_result.success):
460+
rprint(f"Migration: {migration.name} {FAILED_BADGE}")
461+
for error in execution_result.errors:
443462
rprint(f" {error}")
444-
return False
463+
if validation_result and not validation_result.success:
464+
for error in validation_result.errors:
465+
rprint(f" {error}")
466+
return False
445467

446468
return True
447469

backend/infrahub/cli/upgrade.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from rich import print as rprint
1212

1313
from infrahub import config
14-
from infrahub.core.initialization import create_anonymous_role, create_default_account_groups, initialize_registry
14+
from infrahub.core.initialization import (
15+
create_anonymous_role,
16+
create_default_account_groups,
17+
get_root_node,
18+
initialize_registry,
19+
)
1520
from infrahub.core.manager import NodeManager
1621
from infrahub.core.protocols import CoreAccount, CoreObjectPermission
1722
from infrahub.dependencies.registry import build_component_registry
@@ -63,6 +68,8 @@ async def upgrade_cmd(
6368

6469
build_component_registry()
6570

71+
root_node = await get_root_node(db=dbdriver)
72+
6673
# NOTE add step to validate if the database and the task manager are reachable
6774

6875
# -------------------------------------------
@@ -73,8 +80,8 @@ async def upgrade_cmd(
7380
# Upgrade Infrahub Database and Schema
7481
# -------------------------------------------
7582

76-
migrations = await detect_migration_to_run(db=dbdriver)
77-
if check or not migrations:
83+
migrations = await detect_migration_to_run(current_graph_version=root_node.graph_version)
84+
if check:
7885
return
7986

8087
if not await migrate_database(db=dbdriver, initialize=False, migrations=migrations):
@@ -86,7 +93,7 @@ async def upgrade_cmd(
8693
await initialize_internal_schema()
8794
await update_core_schema(db=dbdriver, initialize=False)
8895

89-
if not await rebase_and_migrate_branches(db=dbdriver, migrations=migrations):
96+
if not await rebase_and_migrate_branches(db=dbdriver, current_graph_version=root_node.graph_version):
9097
# A migration failed, stop the upgrade process
9198
rprint("Upgrade cancelled due to branch rebase and migration failure.")
9299
await dbdriver.close()

backend/infrahub/core/migrations/graph/__init__.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@
4747
from .m043_backfill_hfid_display_label_in_db import Migration043
4848

4949
if TYPE_CHECKING:
50-
from infrahub.core.root import Root
51-
5250
from ..shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration, MigrationWithRebase
5351

5452
MIGRATIONS: list[type[GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase]] = [
@@ -99,12 +97,12 @@
9997

10098

10199
async def get_graph_migrations(
102-
root: Root,
100+
current_graph_version: int,
103101
) -> Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase]:
104102
applicable_migrations = []
105103
for migration_class in MIGRATIONS:
106104
migration = migration_class.init()
107-
if root.graph_version > migration.minimum_version:
105+
if current_graph_version > migration.minimum_version:
108106
continue
109107
applicable_migrations.append(migration)
110108

backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async def _update_batch(
5353

5454
progress.update(task_id=update_task, advance=1)
5555

56-
async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
56+
async def process_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
5757
result = MigrationResult()
5858

5959
schemas_to_update: dict[MainSchemaTypes, int] = {}

backend/infrahub/core/migrations/shared.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -269,34 +269,26 @@ async def rebase_branch(self, branch: Branch) -> bool:
269269
async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult:
270270
raise NotImplementedError()
271271

272-
async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
272+
async def process_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
273273
raise NotImplementedError()
274274

275-
async def execute_against_branches(self, db: InfrahubDatabase, branches: Sequence[Branch]) -> MigrationResult:
276-
result = MigrationResult()
277-
278-
for branch in branches:
279-
await registry.schema.load_schema(db=db, branch=branch)
275+
async def execute_against_branch(
276+
self, db: InfrahubDatabase, branch: Branch, skip_rebase: bool = False
277+
) -> MigrationResult:
278+
await registry.schema.load_schema(db=db, branch=branch)
280279

280+
if not skip_rebase:
281281
if not await self.rebase_branch(branch=branch):
282282
branch.status = BranchStatus.NEED_UPGRADE_REBASE
283283
await branch.save(db=db)
284-
continue
285-
286-
r = await self.execute_against_branch(db=db, branch=branch)
287-
result.nbr_migrations_executed += 1
288-
if r.errors:
289-
result.errors.extend(r.errors)
290-
if r.success:
291-
branch.graph_version = self.minimum_version + 1
292-
await branch.save(db=db)
284+
return MigrationResult()
293285

294-
return result
286+
return await self.process_branch(db=db, branch=branch)
295287

296288
async def execute(self, db: InfrahubDatabase) -> MigrationResult:
297289
from infrahub.core.initialization import initialization
298290

299291
initialize_lock()
300292
await initialization(db=db)
301293

302-
return await self.execute_against_branch(db=db, branch=registry.get_branch_from_registry())
294+
return await self.execute_against_branch(db=db, branch=registry.get_branch_from_registry(), skip_rebase=True)

0 commit comments

Comments
 (0)