diff --git a/backend/infrahub/cli/db.py b/backend/infrahub/cli/db.py index cd7c196dc3..9c1f0b4bad 100644 --- a/backend/infrahub/cli/db.py +++ b/backend/infrahub/cli/db.py @@ -18,7 +18,12 @@ from rich.table import Table from infrahub import config +from infrahub.auth import AccountSession, AuthType +from infrahub.context import InfrahubContext from infrahub.core import registry +from infrahub.core.branch import Branch +from infrahub.core.branch.tasks import rebase_branch +from infrahub.core.constants import GLOBAL_BRANCH_NAME from infrahub.core.graph import GRAPH_VERSION from infrahub.core.graph.constraints import ConstraintManagerBase, ConstraintManagerMemgraph, ConstraintManagerNeo4j from infrahub.core.graph.index import node_indexes, rel_indexes @@ -30,10 +35,8 @@ GraphRelationshipIsPartOf, GraphRelationshipProperties, ) -from infrahub.core.initialization import ( - get_root_node, - initialize_registry, -) +from infrahub.core.initialization import get_root_node, initialize_registry +from infrahub.core.migrations.exceptions import MigrationFailureError from infrahub.core.migrations.graph import get_graph_migrations, get_migration_by_number from infrahub.core.migrations.schema.models import SchemaApplyMigrationData from infrahub.core.migrations.schema.tasks import schema_apply_migrations @@ -45,6 +48,7 @@ from infrahub.database import DatabaseType from infrahub.database.memgraph import IndexManagerMemgraph from infrahub.database.neo4j import IndexManagerNeo4j +from infrahub.exceptions import ValidationError from .constants import ERROR_BADGE, FAILED_BADGE, SUCCESS_BADGE from .db_commands.check_inheritance import check_inheritance @@ -59,7 +63,7 @@ def get_timestamp_string() -> str: if TYPE_CHECKING: from infrahub.cli.context import CliContext - from infrahub.core.migrations.shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration + from infrahub.core.migrations.shared import MigrationTypes from infrahub.database import InfrahubDatabase from infrahub.database.index import IndexManagerBase @@ -105,7 +109,15 @@ async def migrate_cmd( context: CliContext = ctx.obj dbdriver = await context.init_db(retry=1) - await migrate_database(db=dbdriver, initialize=True, check=check, migration_number=migration_number) + root_node = await get_root_node(db=dbdriver) + migrations = await detect_migration_to_run( + current_graph_version=root_node.graph_version, migration_number=migration_number + ) + + if check or not migrations: + return + + await migrate_database(db=dbdriver, migrations=migrations, initialize=True) await dbdriver.close() @@ -268,8 +280,38 @@ async def index( await dbdriver.close() +async def detect_migration_to_run( + current_graph_version: int, migration_number: int | str | None = None +) -> Sequence[MigrationTypes]: + """Return a sequence of migrations to apply to upgrade the database.""" + rprint("Checking current state of the database") + migrations: list[MigrationTypes] = [] + + if migration_number: + migration = get_migration_by_number(migration_number) + migrations.append(migration) + if current_graph_version > migration.minimum_version: + rprint( + f"Migration {migration_number} already applied. To apply again, run the command without the --check flag." + ) + return [] + rprint( + f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations." + ) + else: + migrations.extend(await get_graph_migrations(current_graph_version=current_graph_version)) + if not migrations: + rprint(f"Database up-to-date (v{current_graph_version}), no migration to execute.") + return [] + + rprint( + f"Database needs to be updated (v{current_graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending" + ) + return migrations + + async def migrate_database( - db: InfrahubDatabase, initialize: bool = False, check: bool = False, migration_number: int | str | None = None + db: InfrahubDatabase, migrations: Sequence[MigrationTypes], initialize: bool = False ) -> bool: """Apply the latest migrations to the database, this function will print the status directly in the console. @@ -277,40 +319,16 @@ async def migrate_database( Args: db: The database object. - check: If True, the function will only check the status of the database and not apply the migrations. Defaults to False. - migration_number: If provided, the function will only apply the migration with the given number. Defaults to None. + migrations: Sequence of migrations to apply. + initialize: Whether to initialize the registry before running migrations. """ - rprint("Checking current state of the Database") + if not migrations: + return True if initialize: await initialize_registry(db=db) root_node = await get_root_node(db=db) - if migration_number: - migration = get_migration_by_number(migration_number) - migrations: Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration] = [migration] - if check: - if root_node.graph_version > migration.minimum_version: - rprint( - f"Migration {migration_number} already applied. To apply again, run the command without the --check flag." - ) - return True - rprint( - f"Migration {migration_number} needs to be applied. Run `infrahub db migrate` to apply all outstanding migrations." - ) - return False - else: - migrations = await get_graph_migrations(root=root_node) - if not migrations: - rprint(f"Database up-to-date (v{root_node.graph_version}), no migration to execute.") - return True - - rprint( - f"Database needs to be updated (v{root_node.graph_version} -> v{GRAPH_VERSION}), {len(migrations)} migrations pending" - ) - - if check: - return True for migration in migrations: execution_result = await migration.execute(db=db) @@ -335,6 +353,36 @@ async def migrate_database( return True +async def trigger_rebase_branches(db: InfrahubDatabase) -> None: + """Trigger rebase of non-default branches, also triggering migrations in the process.""" + branches = [b for b in await Branch.get_list(db=db) if b.name not in [registry.default_branch, GLOBAL_BRANCH_NAME]] + if not branches: + return + + rprint(f"Planning rebase and migrations for {len(branches)} branches: {', '.join([b.name for b in branches])}") + + for branch in branches: + if branch.graph_version == GRAPH_VERSION: + rprint( + f"Ignoring branch rebase and migrations for '{branch.name}' (ID: {branch.uuid}), it is already up-to-date" + ) + continue + + rprint(f"Rebasing branch '{branch.name}' (ID: {branch.uuid})...", end="") + try: + await registry.schema.load_schema(db=db, branch=branch) + await rebase_branch( + branch=branch.name, + context=InfrahubContext.init( + branch=branch, account=AccountSession(auth_type=AuthType.NONE, authenticated=False, account_id="") + ), + send_events=False, + ) + rprint(SUCCESS_BADGE) + except (ValidationError, MigrationFailureError): + rprint(FAILED_BADGE) + + async def initialize_internal_schema() -> None: registry.schema = SchemaManager() schema = SchemaRoot(**internal_schema) diff --git a/backend/infrahub/cli/upgrade.py b/backend/infrahub/cli/upgrade.py index cad4ff6d07..6a65c6c815 100644 --- a/backend/infrahub/cli/upgrade.py +++ b/backend/infrahub/cli/upgrade.py @@ -11,7 +11,12 @@ from rich import print as rprint from infrahub import config -from infrahub.core.initialization import create_anonymous_role, create_default_account_groups, initialize_registry +from infrahub.core.initialization import ( + create_anonymous_role, + create_default_account_groups, + get_root_node, + initialize_registry, +) from infrahub.core.manager import NodeManager from infrahub.core.protocols import CoreAccount, CoreObjectPermission from infrahub.dependencies.registry import build_component_registry @@ -26,7 +31,13 @@ setup_worker_pools, ) -from .db import initialize_internal_schema, migrate_database, update_core_schema +from .db import ( + detect_migration_to_run, + initialize_internal_schema, + migrate_database, + trigger_rebase_branches, + update_core_schema, +) if TYPE_CHECKING: from infrahub.cli.context import CliContext @@ -40,6 +51,7 @@ async def upgrade_cmd( ctx: typer.Context, config_file: str = typer.Argument("infrahub.toml", envvar="INFRAHUB_CONFIG"), check: bool = typer.Option(False, help="Check the state of the system without upgrading."), + rebase_branches: bool = typer.Option(False, help="Rebase and apply migrations to branches if required."), ) -> None: """Upgrade Infrahub to the latest version.""" @@ -57,6 +69,8 @@ async def upgrade_cmd( build_component_registry() + root_node = await get_root_node(db=dbdriver) + # NOTE add step to validate if the database and the task manager are reachable # ------------------------------------------- @@ -67,7 +81,12 @@ async def upgrade_cmd( # Upgrade Infrahub Database and Schema # ------------------------------------------- - if not await migrate_database(db=dbdriver, initialize=False, check=check): + migrations = await detect_migration_to_run(current_graph_version=root_node.graph_version) + if check: + await dbdriver.close() + return + + if not await migrate_database(db=dbdriver, initialize=False, migrations=migrations): # A migration failed, stop the upgrade process rprint("Upgrade cancelled due to migration failure.") await dbdriver.close() @@ -91,6 +110,12 @@ async def upgrade_cmd( await setup_deployments(client=client) await trigger_configure_all() + # ------------------------------------------- + # Perform branch rebase and apply migrations to them + # ------------------------------------------- + if rebase_branches: + await trigger_rebase_branches(db=dbdriver) + await dbdriver.close() diff --git a/backend/infrahub/core/branch/enums.py b/backend/infrahub/core/branch/enums.py index 4aff241f05..5bc0e4e165 100644 --- a/backend/infrahub/core/branch/enums.py +++ b/backend/infrahub/core/branch/enums.py @@ -4,4 +4,5 @@ class BranchStatus(InfrahubStringEnum): OPEN = "OPEN" NEED_REBASE = "NEED_REBASE" + NEED_UPGRADE_REBASE = "NEED_UPGRADE_REBASE" DELETING = "DELETING" diff --git a/backend/infrahub/core/branch/models.py b/backend/infrahub/core/branch/models.py index 9456e31e4b..86cfb341b5 100644 --- a/backend/infrahub/core/branch/models.py +++ b/backend/infrahub/core/branch/models.py @@ -6,9 +6,8 @@ from pydantic import Field, field_validator from infrahub.core.branch.enums import BranchStatus -from infrahub.core.constants import ( - GLOBAL_BRANCH_NAME, -) +from infrahub.core.constants import GLOBAL_BRANCH_NAME +from infrahub.core.graph import GRAPH_VERSION from infrahub.core.models import SchemaBranchHash # noqa: TC001 from infrahub.core.node.standard import StandardNode from infrahub.core.query import QueryType @@ -46,6 +45,7 @@ class Branch(StandardNode): is_isolated: bool = True schema_changed_at: Optional[str] = None schema_hash: Optional[SchemaBranchHash] = None + graph_version: int | None = None _exclude_attrs: list[str] = ["id", "uuid", "owner"] @@ -261,6 +261,10 @@ def get_branches_and_times_for_range( return start, end + async def create(self, db: InfrahubDatabase) -> bool: + self.graph_version = GRAPH_VERSION + return await super().create(db=db) + async def delete(self, db: InfrahubDatabase) -> None: if self.is_default: raise ValidationError(f"Unable to delete {self.name} it is the default branch.") diff --git a/backend/infrahub/core/branch/tasks.py b/backend/infrahub/core/branch/tasks.py index 5843b92d75..d8e80abced 100644 --- a/backend/infrahub/core/branch/tasks.py +++ b/backend/infrahub/core/branch/tasks.py @@ -12,6 +12,7 @@ from infrahub.context import InfrahubContext # noqa: TC001 needed for prefect flow from infrahub.core import registry from infrahub.core.branch import Branch +from infrahub.core.branch.enums import BranchStatus from infrahub.core.changelog.diff import DiffChangelogCollector, MigrationTracker from infrahub.core.constants import MutationAction from infrahub.core.diff.coordinator import DiffCoordinator @@ -21,7 +22,10 @@ from infrahub.core.diff.model.path import BranchTrackingId, EnrichedDiffRoot, EnrichedDiffRootMetadata from infrahub.core.diff.models import RequestDiffUpdate from infrahub.core.diff.repository.repository import DiffRepository +from infrahub.core.graph import GRAPH_VERSION from infrahub.core.merge import BranchMerger +from infrahub.core.migrations.exceptions import MigrationFailureError +from infrahub.core.migrations.runner import MigrationRunner from infrahub.core.migrations.schema.models import SchemaApplyMigrationData from infrahub.core.migrations.schema.tasks import schema_apply_migrations from infrahub.core.timestamp import Timestamp @@ -29,7 +33,13 @@ from infrahub.core.validators.models.validate_migration import SchemaValidateMigrationData from infrahub.core.validators.tasks import schema_validate_migrations from infrahub.dependencies.registry import get_component_registry -from infrahub.events.branch_action import BranchCreatedEvent, BranchDeletedEvent, BranchMergedEvent, BranchRebasedEvent +from infrahub.events.branch_action import ( + BranchCreatedEvent, + BranchDeletedEvent, + BranchMergedEvent, + BranchMigratedEvent, + BranchRebasedEvent, +) from infrahub.events.models import EventMeta, InfrahubEvent from infrahub.events.node_action import get_node_event from infrahub.exceptions import BranchNotFoundError, ValidationError @@ -49,8 +59,57 @@ from infrahub.workflows.utils import add_tags +@flow(name="branch-migrate", flow_run_name="Apply migrations to branch {branch}") +async def migrate_branch(branch: str, context: InfrahubContext, send_events: bool = True) -> None: + await add_tags(branches=[branch]) + + database = await get_database() + async with database.start_session() as db: + log = get_run_logger() + + obj = await Branch.get_by_name(db=db, name=branch) + + if obj.graph_version == GRAPH_VERSION: + log.info(f"Branch '{obj.name}' has graph version {obj.graph_version}, no migrations to apply") + return + + migration_runner = MigrationRunner(branch=obj) + if not migration_runner.has_migrations(): + log.info(f"No migrations detected for branch '{obj.name}'") + obj.graph_version = GRAPH_VERSION + await obj.save(db=db) + return + + # Branch status will remain as so if the migration process fails + # This will help user to know that a branch is in an invalid state to be used properly and that actions need to be taken + if obj.status != BranchStatus.NEED_UPGRADE_REBASE: + obj.status = BranchStatus.NEED_UPGRADE_REBASE + await obj.save(db=db) + + try: + log.info(f"Running migrations for branch '{obj.name}'") + await migration_runner.run(db=db) + except MigrationFailureError as exc: + log.error(f"Failed to run migrations for branch '{obj.name}': {exc.errors}") + raise + + if obj.status == BranchStatus.NEED_UPGRADE_REBASE: + obj.status = BranchStatus.OPEN + obj.graph_version = GRAPH_VERSION + await obj.save(db=db) + + if send_events: + event_service = await get_event_service() + await event_service.send( + BranchMigratedEvent( + branch_name=obj.name, branch_id=str(obj.uuid), meta=EventMeta(branch=obj, context=context) + ) + ) + + @flow(name="branch-rebase", flow_run_name="Rebase branch {branch}") -async def rebase_branch(branch: str, context: InfrahubContext) -> None: # noqa: PLR0915 +async def rebase_branch(branch: str, context: InfrahubContext, send_events: bool = True) -> None: # noqa: PLR0915 + workflow = get_workflow() database = await get_database() async with database.start_session() as db: log = get_run_logger() @@ -69,7 +128,7 @@ async def rebase_branch(branch: str, context: InfrahubContext) -> None: # noqa: diff_repository=diff_repository, source_branch=obj, diff_locker=DiffLocker(), - workflow=get_workflow(), + workflow=workflow, ) enriched_diff_metadata = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj) @@ -156,15 +215,17 @@ async def rebase_branch(branch: str, context: InfrahubContext) -> None: # noqa: target_branch_name=registry.default_branch, ) if ipam_node_details: - await get_workflow().submit_workflow( + await workflow.submit_workflow( workflow=IPAM_RECONCILIATION, context=context, parameters={"branch": obj.name, "ipam_node_details": ipam_node_details}, ) - await get_workflow().submit_workflow( - workflow=DIFF_REFRESH_ALL, context=context, parameters={"branch_name": obj.name} - ) + await migrate_branch(branch=branch, context=context, send_events=send_events) + await workflow.submit_workflow(workflow=DIFF_REFRESH_ALL, context=context, parameters={"branch_name": obj.name}) + + if not send_events: + return # ------------------------------------------------------------- # Generate an event to indicate that a branch has been rebased diff --git a/backend/infrahub/core/constants/__init__.py b/backend/infrahub/core/constants/__init__.py index b12a034fff..524ca1ce59 100644 --- a/backend/infrahub/core/constants/__init__.py +++ b/backend/infrahub/core/constants/__init__.py @@ -51,6 +51,7 @@ class EventType(InfrahubStringEnum): BRANCH_CREATED = f"{EVENT_NAMESPACE}.branch.created" BRANCH_DELETED = f"{EVENT_NAMESPACE}.branch.deleted" BRANCH_MERGED = f"{EVENT_NAMESPACE}.branch.merged" + BRANCH_MIGRATED = f"{EVENT_NAMESPACE}.branch.migrated" BRANCH_REBASED = f"{EVENT_NAMESPACE}.branch.rebased" SCHEMA_UPDATED = f"{EVENT_NAMESPACE}.schema.updated" diff --git a/backend/infrahub/core/migrations/exceptions.py b/backend/infrahub/core/migrations/exceptions.py new file mode 100644 index 0000000000..6edeecb7ec --- /dev/null +++ b/backend/infrahub/core/migrations/exceptions.py @@ -0,0 +1,4 @@ +class MigrationFailureError(Exception): + def __init__(self, errors: list[str]) -> None: + super().__init__() + self.errors = errors diff --git a/backend/infrahub/core/migrations/graph/__init__.py b/backend/infrahub/core/migrations/graph/__init__.py index c7e661973d..8d5ce89654 100644 --- a/backend/infrahub/core/migrations/graph/__init__.py +++ b/backend/infrahub/core/migrations/graph/__init__.py @@ -47,11 +47,10 @@ from .m043_backfill_hfid_display_label_in_db import Migration043 if TYPE_CHECKING: - from infrahub.core.root import Root + from ..shared import MigrationTypes - from ..shared import ArbitraryMigration, GraphMigration, InternalSchemaMigration -MIGRATIONS: list[type[GraphMigration | InternalSchemaMigration | ArbitraryMigration]] = [ +MIGRATIONS: list[type[MigrationTypes]] = [ Migration001, Migration002, Migration003, @@ -98,22 +97,18 @@ ] -async def get_graph_migrations( - root: Root, -) -> Sequence[GraphMigration | InternalSchemaMigration | ArbitraryMigration]: +async def get_graph_migrations(current_graph_version: int) -> Sequence[MigrationTypes]: applicable_migrations = [] for migration_class in MIGRATIONS: migration = migration_class.init() - if root.graph_version > migration.minimum_version: + if current_graph_version > migration.minimum_version: continue applicable_migrations.append(migration) return applicable_migrations -def get_migration_by_number( - migration_number: int | str, -) -> GraphMigration | InternalSchemaMigration | ArbitraryMigration: +def get_migration_by_number(migration_number: int | str) -> MigrationTypes: # Convert to string and pad with zeros if needed try: num = int(migration_number) diff --git a/backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py b/backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py index 046ef08016..12d79d8a78 100644 --- a/backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py +++ b/backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py @@ -18,7 +18,7 @@ from infrahub.exceptions import InitializationError from infrahub.types import is_large_attribute_type -from ..shared import ArbitraryMigration +from ..shared import MigrationRequiringRebase if TYPE_CHECKING: from infrahub.core.schema.basenode_schema import SchemaAttributePath @@ -600,7 +600,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> No self.add_to_query(set_value_query) -class Migration043(ArbitraryMigration): +class Migration043(MigrationRequiringRebase): """ Backfill `human_friendly_id` and `display_label` attributes for nodes with schemas that define them. """ diff --git a/backend/infrahub/core/migrations/runner.py b/backend/infrahub/core/migrations/runner.py new file mode 100644 index 0000000000..da62c44650 --- /dev/null +++ b/backend/infrahub/core/migrations/runner.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Sequence + +from infrahub.core import registry +from infrahub.core.constants import GLOBAL_BRANCH_NAME +from infrahub.core.migrations.graph import MIGRATIONS + +from .exceptions import MigrationFailureError +from .shared import MigrationRequiringRebase + +if TYPE_CHECKING: + from infrahub.core.branch import Branch + from infrahub.database import InfrahubDatabase + + +class MigrationRunner: + def __init__(self, branch: Branch) -> None: + if branch.name in (registry.default_branch, GLOBAL_BRANCH_NAME): + raise ValueError("MigrationRunner cannot be used to apply migration on default branches") + + self.branch = branch + self.applicable_migrations = self._get_applicable_migrations() + + def _get_applicable_migrations(self) -> Sequence[MigrationRequiringRebase]: + applicable_migrations = [] + for migration_class in [m for m in MIGRATIONS if issubclass(m, MigrationRequiringRebase)]: + migration = migration_class.init() + if self.branch.graph_version and self.branch.graph_version > migration.minimum_version: + continue + applicable_migrations.append(migration) + + return applicable_migrations + + def has_migrations(self) -> bool: + return bool(self.applicable_migrations) + + async def run(self, db: InfrahubDatabase) -> None: + if not self.has_migrations(): + return + + for migration in self.applicable_migrations: + execution_result = await migration.execute_against_branch(db=db, branch=self.branch) + validation_result = None + + if execution_result.success: + validation_result = await migration.validate_migration(db=db) + + if not execution_result.success or (validation_result and not validation_result.success): + if execution_result.errors: + raise MigrationFailureError(errors=execution_result.errors) + + if validation_result and not validation_result.success: + raise MigrationFailureError(errors=validation_result.errors) diff --git a/backend/infrahub/core/migrations/shared.py b/backend/infrahub/core/migrations/shared.py index 544f4c4462..ab1548a651 100644 --- a/backend/infrahub/core/migrations/shared.py +++ b/backend/infrahub/core/migrations/shared.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Sequence +from typing import TYPE_CHECKING, Any, Sequence, TypeAlias from pydantic import BaseModel, ConfigDict, Field from typing_extensions import Self @@ -8,13 +8,7 @@ from infrahub.core import registry from infrahub.core.path import SchemaPath # noqa: TC001 from infrahub.core.query import Query # noqa: TC001 -from infrahub.core.schema import ( - AttributeSchema, - MainSchemaTypes, - RelationshipSchema, - SchemaRoot, - internal_schema, -) +from infrahub.core.schema import AttributeSchema, MainSchemaTypes, RelationshipSchema, SchemaRoot, internal_schema from infrahub.core.timestamp import Timestamp from .query import MigrationBaseQuery # noqa: TC001 @@ -230,3 +224,27 @@ async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: async def execute(self, db: InfrahubDatabase) -> MigrationResult: raise NotImplementedError() + + +class MigrationRequiringRebase(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + name: str = Field(..., description="Name of the migration") + minimum_version: int = Field(..., description="Minimum version of the graph to execute this migration") + + @classmethod + def init(cls, **kwargs: dict[str, Any]) -> Self: + return cls(**kwargs) # type: ignore[arg-type] + + async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: + raise NotImplementedError() + + async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult: + """Method that will be run against non-default branches, it assumes that the branches have been rebased.""" + raise NotImplementedError() + + async def execute(self, db: InfrahubDatabase) -> MigrationResult: + """Method that will be run against the default branch.""" + raise NotImplementedError() + + +MigrationTypes: TypeAlias = GraphMigration | InternalSchemaMigration | ArbitraryMigration | MigrationRequiringRebase diff --git a/backend/infrahub/events/branch_action.py b/backend/infrahub/events/branch_action.py index 8c43bb454c..9d40efacf1 100644 --- a/backend/infrahub/events/branch_action.py +++ b/backend/infrahub/events/branch_action.py @@ -109,7 +109,7 @@ class BranchRebasedEvent(InfrahubEvent): event_name: ClassVar[str] = f"{EVENT_NAMESPACE}.branch.rebased" - branch_id: str = Field(..., description="The ID of the mutated node") + branch_id: str = Field(..., description="The ID of the branch") branch_name: str = Field(..., description="The name of the branch") def get_resource(self) -> dict[str, str]: @@ -128,3 +128,29 @@ def get_messages(self) -> list[InfrahubMessage]: RefreshRegistryRebasedBranch(branch=self.branch_name), ] return events + + +class BranchMigratedEvent(InfrahubEvent): + """Event generated when a branch has been migrated""" + + event_name: ClassVar[str] = f"{EVENT_NAMESPACE}.branch.migrated" + + branch_id: str = Field(..., description="The ID of the branch") + branch_name: str = Field(..., description="The name of the branch") + + def get_resource(self) -> dict[str, str]: + return { + "prefect.resource.id": f"infrahub.branch.{self.branch_name}", + "infrahub.branch.id": self.branch_id, + "infrahub.branch.name": self.branch_name, + } + + def get_messages(self) -> list[InfrahubMessage]: + events: list[InfrahubMessage] = [ + # EventBranchMigrated( + # branch=self.branch, + # meta=self.get_message_meta(), + # ), + RefreshRegistryRebasedBranch(branch=self.branch_name), + ] + return events diff --git a/backend/infrahub/graphql/types/branch.py b/backend/infrahub/graphql/types/branch.py index 86c79f2b62..285e02b026 100644 --- a/backend/infrahub/graphql/types/branch.py +++ b/backend/infrahub/graphql/types/branch.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any -from graphene import Boolean, Field, String +from graphene import Boolean, Field, Int, String from infrahub.core.branch import Branch from infrahub.core.constants import GLOBAL_BRANCH_NAME @@ -21,6 +21,7 @@ class BranchType(InfrahubObjectType): origin_branch = String(required=False) branched_from = String(required=False) status = InfrahubBranchStatus(required=True) + graph_version = Int(required=False) created_at = String(required=False) sync_with_git = Boolean(required=False) is_default = Boolean(required=False) diff --git a/backend/infrahub/task_manager/event.py b/backend/infrahub/task_manager/event.py index dd973fe931..6f76e532df 100644 --- a/backend/infrahub/task_manager/event.py +++ b/backend/infrahub/task_manager/event.py @@ -160,6 +160,9 @@ def _return_branch_merged(self) -> dict[str, Any]: def _return_branch_rebased(self) -> dict[str, Any]: return {"rebased_branch": self._get_branch_name_from_resource()} + def _return_branch_migrated(self) -> dict[str, Any]: + return {"migrated_branch": self._get_branch_name_from_resource()} + def _return_group_event(self) -> dict[str, Any]: members = [] ancestors = [] @@ -228,6 +231,8 @@ def _return_event_specifics(self) -> dict[str, Any]: event_specifics = self._return_branch_deleted() case "infrahub.branch.merged": event_specifics = self._return_branch_merged() + case "infrahub.branch.migrated": + event_specifics = self._return_branch_migrated() case "infrahub.branch.rebased": event_specifics = self._return_branch_rebased() case "infrahub.group.member_added" | "infrahub.group.member_removed": diff --git a/backend/infrahub/task_manager/models.py b/backend/infrahub/task_manager/models.py index 4e3320185a..3d51de0134 100644 --- a/backend/infrahub/task_manager/models.py +++ b/backend/infrahub/task_manager/models.py @@ -141,6 +141,13 @@ def add_event_type_filter( if branches: self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.branch.name": branches})) + if branch_migrated := event_type_filter.get("branch_migrated"): + branches = branch_migrated.get("branches") or [] + if "infrahub.branch.created" not in event_type: + event_type.append("infrahub.branch.migrated") + if branches: + self.resource = EventResourceFilter(labels=ResourceSpecification({"infrahub.branch.name": branches})) + if branch_rebased := event_type_filter.get("branch_rebased"): branches = branch_rebased.get("branches") or [] if "infrahub.branch.created" not in event_type: diff --git a/backend/infrahub/workflows/catalogue.py b/backend/infrahub/workflows/catalogue.py index 0fde0d01e1..f1253ee69c 100644 --- a/backend/infrahub/workflows/catalogue.py +++ b/backend/infrahub/workflows/catalogue.py @@ -230,6 +230,13 @@ function="rebase_branch", tags=[WorkflowTag.DATABASE_CHANGE], ) +BRANCH_MIGRATE = WorkflowDefinition( + name="branch-migrate", + type=WorkflowType.CORE, + module="infrahub.core.branch.tasks", + function="migrate_branch", + tags=[WorkflowTag.DATABASE_CHANGE], +) BRANCH_CREATE = WorkflowDefinition( name="create-branch", @@ -641,6 +648,7 @@ BRANCH_MERGED, BRANCH_MERGE_MUTATION, BRANCH_MERGE_POST_PROCESS, + BRANCH_MIGRATE, BRANCH_REBASE, BRANCH_VALIDATE, CLEAN_UP_DEADLOCKS, diff --git a/backend/tests/unit/core/migrations/test_runner.py b/backend/tests/unit/core/migrations/test_runner.py new file mode 100644 index 0000000000..bc6f417db7 --- /dev/null +++ b/backend/tests/unit/core/migrations/test_runner.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from infrahub.core.branch.models import Branch +from infrahub.core.constants import GLOBAL_BRANCH_NAME +from infrahub.core.initialization import create_branch +from infrahub.core.migrations.runner import MigrationRunner + +if TYPE_CHECKING: + from infrahub.database import InfrahubDatabase + + +def test_instantiation(default_branch: Branch) -> None: + assert MigrationRunner(branch=Branch(name="foo")) + + with pytest.raises(ValueError): + MigrationRunner(branch=default_branch) + + with pytest.raises(ValueError): + MigrationRunner(branch=Branch(name=GLOBAL_BRANCH_NAME)) + + +async def test_has_migrations(default_branch: Branch, db: InfrahubDatabase) -> None: + branch = await create_branch(db=db, branch_name="foo") + + runner = MigrationRunner(branch=branch) + assert not runner.has_migrations() + + branch.graph_version = None + runner = MigrationRunner(branch=branch) + await branch.save(db=db) + assert runner.has_migrations() + + branch.graph_version = 40 + runner = MigrationRunner(branch=branch) + await branch.save(db=db) + assert runner.has_migrations() + + +async def test_applicable_migrations(default_branch: Branch, db: InfrahubDatabase) -> None: + branch = await create_branch(db=db, branch_name="foo") + + runner = MigrationRunner(branch=branch) + assert not runner.applicable_migrations + + branch.graph_version = None + runner = MigrationRunner(branch=branch) + await branch.save(db=db) + assert runner.applicable_migrations + assert [m.name for m in runner.applicable_migrations][0] == "043_backfill_hfid_display_label_in_db" + + branch.graph_version = 40 + runner = MigrationRunner(branch=branch) + await branch.save(db=db) + assert runner.applicable_migrations + assert [m.name for m in runner.applicable_migrations][0] == "043_backfill_hfid_display_label_in_db" diff --git a/backend/tests/unit/workflows/test_models.py b/backend/tests/unit/workflows/test_models.py index 6d1088768c..8452f500b5 100644 --- a/backend/tests/unit/workflows/test_models.py +++ b/backend/tests/unit/workflows/test_models.py @@ -6,4 +6,5 @@ def test_get_parameters(): assert BRANCH_REBASE.get_parameters() == { "branch": WorkflowParameter(name="branch", type="str", required=True), "context": WorkflowParameter(name="context", type="InfrahubContext", required=True), + "send_events": WorkflowParameter(name="send_events", type="bool", required=False), } diff --git a/changelog/+upgrade-process.changed.md b/changelog/+upgrade-process.changed.md new file mode 100644 index 0000000000..e7230cd7bb --- /dev/null +++ b/changelog/+upgrade-process.changed.md @@ -0,0 +1 @@ +Add a new flow to apply migrations inside branches that will also set a `graph_version` property on branches after running. The upgrade process (via the provided command) is also updated with a new `--rebase-branches` flag to trigger branch rebases as part of the upgrade process. Also add the `NEED_UPGRADE_REBASE` branch status to identify branches that need the migrations flow to run. \ No newline at end of file diff --git a/docs/docs/reference/infrahub-cli/infrahub-upgrade.mdx b/docs/docs/reference/infrahub-cli/infrahub-upgrade.mdx index 75eed71145..6ae091da7c 100644 --- a/docs/docs/reference/infrahub-cli/infrahub-upgrade.mdx +++ b/docs/docs/reference/infrahub-cli/infrahub-upgrade.mdx @@ -15,6 +15,7 @@ $ infrahub upgrade [OPTIONS] [CONFIG_FILE] **Options**: * `--check / --no-check`: Check the state of the system without upgrading. [default: no-check] +* `--rebase-branches / --no-rebase-branches`: Rebase and apply migrations to branches if required. [default: no-rebase-branches] * `--install-completion`: Install completion for the current shell. * `--show-completion`: Show completion for the current shell, to copy it or customize the installation. * `--help`: Show this message and exit. diff --git a/docs/docs/reference/infrahub-events.mdx b/docs/docs/reference/infrahub-events.mdx index 28b2bbb825..9dcc3882d5 100644 --- a/docs/docs/reference/infrahub-events.mdx +++ b/docs/docs/reference/infrahub-events.mdx @@ -164,6 +164,32 @@ For more detailed explanations on how these events are used within Infrahub, see | **proposed_change_id** | The ID of the proposed change that merged this branch if applicable | +### Branch Migrated Event + + +**Type**: infrahub.branch.migrated +**Description**: Event generated when a branch has been migrated + +**Uses node_kind filter for webhooks**: `false` + + + +| Key | Description | +|-----|-------------| +| **meta.branch** | The branch on which originate this event | +| **meta.request_id** | N/A | +| **meta.account_id** | The ID of the account triggering this event | +| **meta.initiator_id** | The worker identity of the initial sender of this message | +| **meta.context** | The context used when originating this event | +| **meta.level** | N/A | +| **meta.has_children** | Indicates if this event might potentially have child events under it. | +| **meta.id** | UUID of the event | +| **meta.parent** | The UUID of the parent event if applicable | +| **meta.ancestors** | Any event used to trigger this event | +| **branch_id** | The ID of the branch | +| **branch_name** | The name of the branch | + + ### Branch Rebased Event @@ -186,7 +212,7 @@ For more detailed explanations on how these events are used within Infrahub, see | **meta.id** | UUID of the event | | **meta.parent** | The UUID of the parent event if applicable | | **meta.ancestors** | Any event used to trigger this event | -| **branch_id** | The ID of the mutated node | +| **branch_id** | The ID of the branch | | **branch_name** | The name of the branch | ## Group events diff --git a/schema/schema.graphql b/schema/schema.graphql index a07a15753d..36812404d0 100644 --- a/schema/schema.graphql +++ b/schema/schema.graphql @@ -144,6 +144,7 @@ type Branch { branched_from: String created_at: String description: String + graph_version: Int has_schema_changes: Boolean id: String! is_default: Boolean @@ -323,6 +324,7 @@ enum BranchRelativePermissionDecision { enum BranchStatus { DELETING NEED_REBASE + NEED_UPGRADE_REBASE OPEN }