2020from infrahub import config
2121from infrahub .core import registry
2222from infrahub .core .branch import Branch
23+ from infrahub .core .branch .enums import BranchStatus
2324from infrahub .core .constants import GLOBAL_BRANCH_NAME
2425from infrahub .core .graph import GRAPH_VERSION
2526from infrahub .core .graph .constraints import ConstraintManagerBase , ConstraintManagerMemgraph , ConstraintManagerNeo4j
@@ -108,7 +109,10 @@ async def migrate_cmd(
108109 context : CliContext = ctx .obj
109110 dbdriver = await context .init_db (retry = 1 )
110111
111- migrations = await detect_migration_to_run (db = dbdriver , migration_number = migration_number )
112+ root_node = await get_root_node (db = dbdriver )
113+ migrations = await detect_migration_to_run (
114+ current_graph_version = root_node .graph_version , migration_number = migration_number
115+ )
112116
113117 if check or not migrations :
114118 return
@@ -277,17 +281,16 @@ async def index(
277281
278282
279283async def detect_migration_to_run (
280- db : InfrahubDatabase , migration_number : int | str | None = None
284+ current_graph_version : int , migration_number : int | str | None = None
281285) -> Sequence [GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase ]:
282286 """Return a sequence of migrations to apply to upgrade the database."""
283287 rprint ("Checking current state of the database" )
284288 migrations : list [GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase ] = []
285289
286- root_node = await get_root_node (db = db )
287290 if migration_number :
288291 migration = get_migration_by_number (migration_number )
289292 migrations .append (migration )
290- if root_node . graph_version > migration .minimum_version :
293+ if current_graph_version > migration .minimum_version :
291294 rprint (
292295 f"Migration { migration_number } already applied. To apply again, run the command without the --check flag."
293296 )
@@ -296,13 +299,13 @@ async def detect_migration_to_run(
296299 f"Migration { migration_number } needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
297300 )
298301 else :
299- migrations .extend (await get_graph_migrations (root = root_node ))
302+ migrations .extend (await get_graph_migrations (current_graph_version = current_graph_version ))
300303 if not migrations :
301- rprint (f"Database up-to-date (v{ root_node . graph_version } ), no migration to execute." )
304+ rprint (f"Database up-to-date (v{ current_graph_version } ), no migration to execute." )
302305 return []
303306
304307 rprint (
305- f"Database needs to be updated (v{ root_node . graph_version } -> v{ GRAPH_VERSION } ), { len (migrations )} migrations pending"
308+ f"Database needs to be updated (v{ current_graph_version } -> v{ GRAPH_VERSION } ), { len (migrations )} migrations pending"
306309 )
307310 return migrations
308311
@@ -320,6 +323,9 @@ async def migrate_database(
320323 db: The database object.
321324 migration_number: If provided, the function will only apply the migration with the given number. Defaults to None.
322325 """
326+ if not migrations :
327+ return True
328+
323329 if initialize :
324330 await initialize_registry (db = db )
325331
@@ -348,33 +354,49 @@ async def migrate_database(
348354 return True
349355
350356
351- async def rebase_and_migrate_branches (
352- db : InfrahubDatabase ,
353- migrations : Sequence [GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationWithRebase ],
354- ) -> bool :
357+ async def rebase_and_migrate_branches (db : InfrahubDatabase , current_graph_version : int ) -> bool :
355358 """Only applies migrations that aim at rebasing branches."""
356- branches = [b for b in await Branch .get_list (db = db ) if b .name not in [registry .default_branch , GLOBAL_BRANCH_NAME ]]
359+ branches = [
360+ b
361+ for b in await Branch .get_list (db = db )
362+ if b .name not in [registry .default_branch , GLOBAL_BRANCH_NAME ]
363+ and (not b .graph_version or b .graph_version < current_graph_version )
364+ ]
365+
366+ if not branches :
367+ return True
368+
357369 rprint (f"Planning rebase and migrations for { len (branches )} branches: { ', ' .join ([b .name for b in branches ])} " )
358370
359- rebase_migrations = [m for m in migrations if isinstance (m , MigrationWithRebase )]
371+ for branch in branches :
372+ migrations = [
373+ m
374+ for m in await detect_migration_to_run (current_graph_version = branch .graph_version or current_graph_version )
375+ if isinstance (m , MigrationWithRebase )
376+ ]
377+ rprint (
378+ f"Detected { len (migrations )} migrations to run against '{ branch .name } ' (ID: { branch .uuid } ): { ', ' .join ([m .name for m in migrations ])} "
379+ )
360380
361- for migration in rebase_migrations :
362- execution_result = await migration .execute_against_branches (db = db , branches = branches )
363- validation_result = None
381+ for migration in migrations :
382+ execution_result = await migration .execute_against_branch (db = db , branch = branch )
383+ validation_result = None
364384
365- if execution_result .success :
366- validation_result = await migration .validate_migration (db = db )
367- if validation_result .success :
368- rprint (f"Migration: { migration .name } { SUCCESS_BADGE } " )
385+ if execution_result .success :
386+ validation_result = await migration .validate_migration (db = db )
387+ if validation_result .success and branch .status != BranchStatus .NEED_UPGRADE_REBASE :
388+ branch .graph_version = migration .minimum_version + 1
389+ await branch .save (db = db )
390+ rprint (f"Migration: { migration .name } { SUCCESS_BADGE } " )
369391
370- if not execution_result .success or (validation_result and not validation_result .success ):
371- rprint (f"Migration: { migration .name } { FAILED_BADGE } " )
372- for error in execution_result .errors :
373- rprint (f" { error } " )
374- if validation_result and not validation_result .success :
375- for error in validation_result .errors :
392+ if not execution_result .success or (validation_result and not validation_result .success ):
393+ rprint (f"Migration: { migration .name } { FAILED_BADGE } " )
394+ for error in execution_result .errors :
376395 rprint (f" { error } " )
377- return False
396+ if validation_result and not validation_result .success :
397+ for error in validation_result .errors :
398+ rprint (f" { error } " )
399+ return False
378400
379401 return True
380402
0 commit comments