2222from infrahub import config
2323from infrahub .core import registry
2424from infrahub .core .branch import Branch
25+ from infrahub .core .branch .enums import BranchStatus
2526from infrahub .core .constants import GLOBAL_BRANCH_NAME
2627from infrahub .core .graph import GRAPH_VERSION
2728from infrahub .core .graph .constraints import ConstraintManagerBase , ConstraintManagerMemgraph , ConstraintManagerNeo4j
@@ -175,7 +176,10 @@ async def migrate_cmd(
175176 context : CliContext = ctx .obj
176177 dbdriver = await context .init_db (retry = 1 )
177178
178- migrations = await detect_migration_to_run (db = dbdriver , migration_number = migration_number )
179+ root_node = await get_root_node (db = dbdriver )
180+ migrations = await detect_migration_to_run (
181+ current_graph_version = root_node .graph_version , migration_number = migration_number
182+ )
179183
180184 if check or not migrations :
181185 return
@@ -344,17 +348,16 @@ async def index(
344348
345349
346350async def detect_migration_to_run (
347- db : InfrahubDatabase , migration_number : int | str | None = None
351+ current_graph_version : int , migration_number : int | str | None = None
348352) -> Sequence [GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase ]:
349353 """Return a sequence of migrations to apply to upgrade the database."""
350354 rprint ("Checking current state of the database" )
351355 migrations : list [GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase ] = []
352356
353- root_node = await get_root_node (db = db )
354357 if migration_number :
355358 migration = get_migration_by_number (migration_number )
356359 migrations .append (migration )
357- if root_node . graph_version > migration .minimum_version :
360+ if current_graph_version > migration .minimum_version :
358361 rprint (
359362 f"Migration { migration_number } already applied. To apply again, run the command without the --check flag."
360363 )
@@ -363,13 +366,13 @@ async def detect_migration_to_run(
363366 f"Migration { migration_number } needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
364367 )
365368 else :
366- migrations .extend (await get_graph_migrations (root = root_node ))
369+ migrations .extend (await get_graph_migrations (current_graph_version = current_graph_version ))
367370 if not migrations :
368- rprint (f"Database up-to-date (v{ root_node . graph_version } ), no migration to execute." )
371+ rprint (f"Database up-to-date (v{ current_graph_version } ), no migration to execute." )
369372 return []
370373
371374 rprint (
372- f"Database needs to be updated (v{ root_node . graph_version } -> v{ GRAPH_VERSION } ), { len (migrations )} migrations pending"
375+ f"Database needs to be updated (v{ current_graph_version } -> v{ GRAPH_VERSION } ), { len (migrations )} migrations pending"
373376 )
374377 return migrations
375378
@@ -387,6 +390,9 @@ async def migrate_database(
387390 db: The database object.
388391 migration_number: If provided, the function will only apply the migration with the given number. Defaults to None.
389392 """
393+ if not migrations :
394+ return True
395+
390396 if initialize :
391397 await initialize_registry (db = db )
392398
@@ -415,33 +421,49 @@ async def migrate_database(
415421 return True
416422
417423
418- async def rebase_and_migrate_branches (
419- db : InfrahubDatabase ,
420- migrations : Sequence [GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase ],
421- ) -> bool :
424+ async def rebase_and_migrate_branches (db : InfrahubDatabase , current_graph_version : int ) -> bool :
422425 """Only applies migrations that aim at rebasing branches."""
423- branches = [b for b in await Branch .get_list (db = db ) if b .name not in [registry .default_branch , GLOBAL_BRANCH_NAME ]]
426+ branches = [
427+ b
428+ for b in await Branch .get_list (db = db )
429+ if b .name not in [registry .default_branch , GLOBAL_BRANCH_NAME ]
430+ and (not b .graph_version or b .graph_version < current_graph_version )
431+ ]
432+
433+ if not branches :
434+ return True
435+
424436 rprint (f"Planning rebase and migrations for { len (branches )} branches: { ', ' .join ([b .name for b in branches ])} " )
425437
426- rebase_migrations = [m for m in migrations if isinstance (m , MigrationWithRebase )]
438+ for branch in branches :
439+ migrations = [
440+ m
441+ for m in await detect_migration_to_run (current_graph_version = branch .graph_version or current_graph_version )
442+ if isinstance (m , MigrationWithRebase )
443+ ]
444+ rprint (
445+ f"Detected { len (migrations )} migrations to run against '{ branch .name } ' (ID: { branch .uuid } ): { ', ' .join ([m .name for m in migrations ])} "
446+ )
427447
428- for migration in rebase_migrations :
429- execution_result = await migration .execute_against_branches (db = db , branches = branches )
430- validation_result = None
448+ for migration in migrations :
449+ execution_result = await migration .execute_against_branch (db = db , branch = branch )
450+ validation_result = None
431451
432- if execution_result .success :
433- validation_result = await migration .validate_migration (db = db )
434- if validation_result .success :
435- rprint (f"Migration: { migration .name } { SUCCESS_BADGE } " )
452+ if execution_result .success :
453+ validation_result = await migration .validate_migration (db = db )
454+ if validation_result .success and branch .status != BranchStatus .NEED_UPGRADE_REBASE :
455+ branch .graph_version = migration .minimum_version + 1
456+ await branch .save (db = db )
457+ rprint (f"Migration: { migration .name } { SUCCESS_BADGE } " )
436458
437- if not execution_result .success or (validation_result and not validation_result .success ):
438- rprint (f"Migration: { migration .name } { FAILED_BADGE } " )
439- for error in execution_result .errors :
440- rprint (f" { error } " )
441- if validation_result and not validation_result .success :
442- for error in validation_result .errors :
459+ if not execution_result .success or (validation_result and not validation_result .success ):
460+ rprint (f"Migration: { migration .name } { FAILED_BADGE } " )
461+ for error in execution_result .errors :
443462 rprint (f" { error } " )
444- return False
463+ if validation_result and not validation_result .success :
464+ for error in validation_result .errors :
465+ rprint (f" { error } " )
466+ return False
445467
446468 return True
447469
0 commit comments