Skip to content

Commit 9e8eebe

Browse files
authored
Allow rebase of branches during migration (#7447)
These changes add a new flow to apply migrations (schema/data ones) inside branches. The upgrade process (via the provided command) is also updated with a new `--rebase-branches` flag to trigger branch rebases as part of the upgrade process. A new `MigrationRequiringRebase` class is introduced with a couple of methods that need to be defined on classes inheriting it. One method defines what to do on the default branch while the other one defines what to do for all other branches. The branch nodes can have a new `NEED_UPGRADE_REBASE` status and they also now have a graph_version property to know if a branch is up-to-date regarding the graph structure. A new prefect event is also introduced and triggered when the `branch-migrate` flow completes for a branch.
1 parent 128fa7b commit 9e8eebe

File tree

22 files changed

+419
-71
lines changed

22 files changed

+419
-71
lines changed

backend/infrahub/cli/db.py

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@
1818
from rich.table import Table
1919

2020
from infrahub import config
21+
from infrahub.auth import AccountSession, AuthType
22+
from infrahub.context import InfrahubContext
2123
from infrahub.core import registry
24+
from infrahub.core.branch import Branch
25+
from infrahub.core.branch.tasks import rebase_branch
26+
from infrahub.core.constants import GLOBAL_BRANCH_NAME
2227
from infrahub.core.graph import GRAPH_VERSION
2328
from infrahub.core.graph.constraints import ConstraintManagerBase, ConstraintManagerMemgraph, ConstraintManagerNeo4j
2429
from infrahub.core.graph.index import node_indexes, rel_indexes
@@ -30,10 +35,8 @@
3035
GraphRelationshipIsPartOf,
3136
GraphRelationshipProperties,
3237
)
33-
from infrahub.core.initialization import (
34-
get_root_node,
35-
initialize_registry,
36-
)
38+
from infrahub.core.initialization import get_root_node, initialize_registry
39+
from infrahub.core.migrations.exceptions import MigrationFailureError
3740
from infrahub.core.migrations.graph import get_graph_migrations, get_migration_by_number
3841
from infrahub.core.migrations.schema.models import SchemaApplyMigrationData
3942
from infrahub.core.migrations.schema.tasks import schema_apply_migrations
@@ -45,6 +48,7 @@
4548
from infrahub.database import DatabaseType
4649
from infrahub.database.memgraph import IndexManagerMemgraph
4750
from infrahub.database.neo4j import IndexManagerNeo4j
51+
from infrahub.exceptions import ValidationError
4852

4953
from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE
5054
from .db_commands.check_inheritance import check_inheritance
@@ -59,7 +63,7 @@ def get_timestamp_string() -> str:
5963

6064
if TYPE_CHECKING:
6165
from infrahub.cli.context import CliContext
62-
from infrahub.core.migrations.shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration
66+
from infrahub.core.migrations.shared import MigrationTypes
6367
from infrahub.database import InfrahubDatabase
6468
from infrahub.database.index import IndexManagerBase
6569

@@ -105,7 +109,15 @@ async def migrate_cmd(
105109
context: CliContext = ctx.obj
106110
dbdriver = await context.init_db(retry=1)
107111

108-
await migrate_database(db=dbdriver, initialize=True, check=check, migration_number=migration_number)
112+
root_node = await get_root_node(db=dbdriver)
113+
migrations = await detect_migration_to_run(
114+
current_graph_version=root_node.graph_version, migration_number=migration_number
115+
)
116+
117+
if check or not migrations:
118+
return
119+
120+
await migrate_database(db=dbdriver, migrations=migrations, initialize=True)
109121

110122
await dbdriver.close()
111123

@@ -268,49 +280,55 @@ async def index(
268280
await dbdriver.close()
269281

270282

283+
async def detect_migration_to_run(
284+
current_graph_version: int, migration_number: int | str | None = None
285+
) -> Sequence[MigrationTypes]:
286+
"""Return a sequence of migrations to apply to upgrade the database."""
287+
rprint("Checking current state of the database")
288+
migrations: list[MigrationTypes] = []
289+
290+
if migration_number:
291+
migration = get_migration_by_number(migration_number)
292+
migrations.append(migration)
293+
if current_graph_version > migration.minimum_version:
294+
rprint(
295+
f"Migration {migration_number} already applied. To apply again, run the command without the --check flag."
296+
)
297+
return []
298+
rprint(
299+
f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
300+
)
301+
else:
302+
migrations.extend(await get_graph_migrations(current_graph_version=current_graph_version))
303+
if not migrations:
304+
rprint(f"Database up-to-date (v{current_graph_version}), no migration to execute.")
305+
return []
306+
307+
rprint(
308+
f"Database needs to be updated (v{current_graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
309+
)
310+
return migrations
311+
312+
271313
async def migrate_database(
272-
db: InfrahubDatabase, initialize: bool = False, check: bool = False, migration_number: int | str | None = None
314+
db: InfrahubDatabase, migrations: Sequence[MigrationTypes], initialize: bool = False
273315
) -> bool:
274316
"""Apply the latest migrations to the database, this function will print the status directly in the console.
275317
276318
Returns a boolean indicating whether a migration failed or if all migrations succeeded.
277319
278320
Args:
279321
db: The database object.
280-
check: If True, the function will only check the status of the database and not apply the migrations. Defaults to False.
281-
migration_number: If provided, the function will only apply the migration with the given number. Defaults to None.
322+
migrations: Sequence of migrations to apply.
323+
initialize: Whether to initialize the registry before running migrations.
282324
"""
283-
rprint("Checking current state of the Database")
325+
if not migrations:
326+
return True
284327

285328
if initialize:
286329
await initialize_registry(db=db)
287330

288331
root_node = await get_root_node(db=db)
289-
if migration_number:
290-
migration = get_migration_by_number(migration_number)
291-
migrations: Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration] = [migration]
292-
if check:
293-
if root_node.graph_version > migration.minimum_version:
294-
rprint(
295-
f"Migration {migration_number} already applied. To apply again, run the command without the --check flag."
296-
)
297-
return True
298-
rprint(
299-
f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations."
300-
)
301-
return False
302-
else:
303-
migrations = await get_graph_migrations(root=root_node)
304-
if not migrations:
305-
rprint(f"Database up-to-date (v{root_node.graph_version}), no migration to execute.")
306-
return True
307-
308-
rprint(
309-
f"Database needs to be updated (v{root_node.graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending"
310-
)
311-
312-
if check:
313-
return True
314332

315333
for migration in migrations:
316334
execution_result = await migration.execute(db=db)
@@ -335,6 +353,36 @@ async def migrate_database(
335353
return True
336354

337355

356+
async def trigger_rebase_branches(db: InfrahubDatabase) -> None:
357+
"""Trigger rebase of non-default branches, also triggering migrations in the process."""
358+
branches = [b for b in await Branch.get_list(db=db) if b.name not in [registry.default_branch, GLOBAL_BRANCH_NAME]]
359+
if not branches:
360+
return
361+
362+
rprint(f"Planning rebase and migrations for {len(branches)} branches: {', '.join([b.name for b in branches])}")
363+
364+
for branch in branches:
365+
if branch.graph_version == GRAPH_VERSION:
366+
rprint(
367+
f"Ignoring branch rebase and migrations for '{branch.name}' (ID: {branch.uuid}), it is already up-to-date"
368+
)
369+
continue
370+
371+
rprint(f"Rebasing branch '{branch.name}' (ID: {branch.uuid})...", end="")
372+
try:
373+
await registry.schema.load_schema(db=db, branch=branch)
374+
await rebase_branch(
375+
branch=branch.name,
376+
context=InfrahubContext.init(
377+
branch=branch, account=AccountSession(auth_type=AuthType.NONE, authenticated=False, account_id="")
378+
),
379+
send_events=False,
380+
)
381+
rprint(SUCCESS_BADGE)
382+
except (ValidationError, MigrationFailureError):
383+
rprint(FAILED_BADGE)
384+
385+
338386
async def initialize_internal_schema() -> None:
339387
registry.schema = SchemaManager()
340388
schema = SchemaRoot(**internal_schema)

backend/infrahub/cli/upgrade.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from rich import print as rprint
1212

1313
from infrahub import config
14-
from infrahub.core.initialization import create_anonymous_role, create_default_account_groups, initialize_registry
14+
from infrahub.core.initialization import (
15+
create_anonymous_role,
16+
create_default_account_groups,
17+
get_root_node,
18+
initialize_registry,
19+
)
1520
from infrahub.core.manager import NodeManager
1621
from infrahub.core.protocols import CoreAccount, CoreObjectPermission
1722
from infrahub.dependencies.registry import build_component_registry
@@ -26,7 +31,13 @@
2631
setup_worker_pools,
2732
)
2833

29-
from .db import initialize_internal_schema, migrate_database, update_core_schema
34+
from .db import (
35+
detect_migration_to_run,
36+
initialize_internal_schema,
37+
migrate_database,
38+
trigger_rebase_branches,
39+
update_core_schema,
40+
)
3041

3142
if TYPE_CHECKING:
3243
from infrahub.cli.context import CliContext
@@ -40,6 +51,7 @@ async def upgrade_cmd(
4051
ctx: typer.Context,
4152
config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"),
4253
check: bool = typer.Option(False, help="Check the state of the system without upgrading."),
54+
rebase_branches: bool = typer.Option(False, help="Rebase and apply migrations to branches if required."),
4355
) -> None:
4456
"""Upgrade Infrahub to the latest version."""
4557

@@ -57,6 +69,8 @@ async def upgrade_cmd(
5769

5870
build_component_registry()
5971

72+
root_node = await get_root_node(db=dbdriver)
73+
6074
# NOTE add step to validate if the database and the task manager are reachable
6175

6276
# -------------------------------------------
@@ -67,7 +81,12 @@ async def upgrade_cmd(
6781
# Upgrade Infrahub Database and Schema
6882
# -------------------------------------------
6983

70-
if not await migrate_database(db=dbdriver, initialize=False, check=check):
84+
migrations = await detect_migration_to_run(current_graph_version=root_node.graph_version)
85+
if check:
86+
await dbdriver.close()
87+
return
88+
89+
if not await migrate_database(db=dbdriver, initialize=False, migrations=migrations):
7190
# A migration failed, stop the upgrade process
7291
rprint("Upgrade cancelled due to migration failure.")
7392
await dbdriver.close()
@@ -91,6 +110,12 @@ async def upgrade_cmd(
91110
await setup_deployments(client=client)
92111
await trigger_configure_all()
93112

113+
# -------------------------------------------
114+
# Perform branch rebase and apply migrations to them
115+
# -------------------------------------------
116+
if rebase_branches:
117+
await trigger_rebase_branches(db=dbdriver)
118+
94119
await dbdriver.close()
95120

96121

backend/infrahub/core/branch/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@
44
class BranchStatus(InfrahubStringEnum):
55
OPEN = "OPEN"
66
NEED_REBASE = "NEED_REBASE"
7+
NEED_UPGRADE_REBASE = "NEED_UPGRADE_REBASE"
78
DELETING = "DELETING"

backend/infrahub/core/branch/models.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@
66
from pydantic import Field, field_validator
77

88
from infrahub.core.branch.enums import BranchStatus
9-
from infrahub.core.constants import (
10-
GLOBAL_BRANCH_NAME,
11-
)
9+
from infrahub.core.constants import GLOBAL_BRANCH_NAME
10+
from infrahub.core.graph import GRAPH_VERSION
1211
from infrahub.core.models import SchemaBranchHash # noqa: TC001
1312
from infrahub.core.node.standard import StandardNode
1413
from infrahub.core.query import QueryType
@@ -46,6 +45,7 @@ class Branch(StandardNode):
4645
is_isolated: bool = True
4746
schema_changed_at: Optional[str] = None
4847
schema_hash: Optional[SchemaBranchHash] = None
48+
graph_version: int | None = None
4949

5050
_exclude_attrs: list[str] = ["id", "uuid", "owner"]
5151

@@ -261,6 +261,10 @@ def get_branches_and_times_for_range(
261261

262262
return start, end
263263

264+
async def create(self, db: InfrahubDatabase) -> bool:
265+
self.graph_version = GRAPH_VERSION
266+
return await super().create(db=db)
267+
264268
async def delete(self, db: InfrahubDatabase) -> None:
265269
if self.is_default:
266270
raise ValidationError(f"Unable to delete {self.name} it is the default branch.")

0 commit comments

Comments
 (0)