|
6 | 6 | from typing_extensions import Self |
7 | 7 |
|
8 | 8 | from infrahub.core import registry |
| 9 | +from infrahub.core.branch import Branch |
9 | 10 | from infrahub.core.constants import SchemaPathType |
10 | | -from infrahub.core.initialization import initialization |
| 11 | +from infrahub.core.initialization import get_root_node |
11 | 12 | from infrahub.core.migrations.schema.node_attribute_add import NodeAttributeAddMigration |
12 | 13 | from infrahub.core.migrations.shared import InternalSchemaMigration, MigrationResult |
13 | 14 | from infrahub.core.path import SchemaPath |
14 | | -from infrahub.lock import initialize_lock |
| 15 | +from infrahub.core.query import Query, QueryType |
| 16 | +from infrahub.core.schema import SchemaRoot, internal_schema |
| 17 | +from infrahub.core.schema.manager import SchemaManager |
| 18 | +from infrahub.exceptions import InitializationError |
15 | 19 |
|
16 | 20 | if TYPE_CHECKING: |
| 21 | + from infrahub.core.schema.schema_branch import SchemaBranch |
17 | 22 | from infrahub.database import InfrahubDatabase |
18 | 23 |
|
19 | 24 |
|
| 25 | +class GetAddedNodesByKindForBranchQuery(Query): |
| 26 | + name = "get_added_nodes_by_kind_for_branch_query" |
| 27 | + type = QueryType.READ |
| 28 | + insert_return = True |
| 29 | + |
| 30 | + async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002 |
| 31 | + self.params["branch"] = self.branch.name |
| 32 | + query = """ |
| 33 | +MATCH (n:Node)-[e:IS_PART_OF {branch: $branch, status: "active"}]->(:Root) |
| 34 | +WHERE e.to IS NULL |
| 35 | +AND NOT exists((n)-[:IS_PART_OF {branch: $branch, status: "deleted"}]->(:Root)) |
| 36 | +WITH n.kind AS kind, collect(n.uuid) AS node_ids |
| 37 | + """ |
| 38 | + self.return_labels = ["kind", "node_ids"] |
| 39 | + self.add_to_query(query) |
| 40 | + |
| 41 | + def get_node_ids_by_kind(self) -> dict[str, list[str]]: |
| 42 | + node_ids_by_kind: dict[str, list[str]] = {} |
| 43 | + for result in self.get_results(): |
| 44 | + kind = result.get_as_type(label="kind", return_type=str) |
| 45 | + node_ids: list[str] = result.get_as_type(label="node_ids", return_type=list) |
| 46 | + node_ids_by_kind[kind] = node_ids |
| 47 | + return node_ids_by_kind |
| 48 | + |
| 49 | + |
20 | 50 | class Migration042(InternalSchemaMigration): |
21 | 51 | name: str = "042_create_hfid_display_label_in_db" |
22 | 52 | minimum_version: int = 41 |
@@ -46,15 +76,25 @@ def init(cls, **kwargs: Any) -> Self: |
46 | 76 | ] |
47 | 77 | return cls(migrations=cls.migrations, **kwargs) # type: ignore[arg-type] |
48 | 78 |
|
| 79 | + async def _get_or_load_schema_branch(self, db: InfrahubDatabase, branch: Branch) -> SchemaBranch: |
| 80 | + try: |
| 81 | + if registry.schema.has_schema_branch(branch.name): |
| 82 | + return registry.schema.get_schema_branch(branch.name) |
| 83 | + except InitializationError: |
| 84 | + pass |
| 85 | + schema_manager = SchemaManager() |
| 86 | + internal_schema_root = SchemaRoot(**internal_schema) |
| 87 | + schema_manager.register_schema(schema=internal_schema_root) |
| 88 | + registry.schema = schema_manager |
| 89 | + return await schema_manager.load_schema_from_db(db=db, branch=branch) |
| 90 | + |
49 | 91 | async def execute(self, db: InfrahubDatabase) -> MigrationResult: |
50 | 92 | result = MigrationResult() |
51 | 93 |
|
52 | | - # load schemas from database into registry |
53 | | - initialize_lock() |
54 | | - await initialization(db=db) |
55 | | - |
56 | | - default_branch = registry.get_branch_from_registry() |
57 | | - schema_branch = await registry.schema.load_schema_from_db(db=db, branch=default_branch) |
| 94 | + root_node = await get_root_node(db=db, initialize=False) |
| 95 | + default_branch_name = root_node.default_branch |
| 96 | + default_branch = await Branch.get_by_name(db=db, name=default_branch_name) |
| 97 | + schema_branch = await self._get_or_load_schema_branch(db=db, branch=default_branch) |
58 | 98 |
|
59 | 99 | migrations = list(self.migrations) |
60 | 100 |
|
@@ -93,5 +133,54 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult: |
93 | 133 |
|
94 | 134 | return result |
95 | 135 |
|
| 136 | + async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult: |
| 137 | + result = MigrationResult() |
| 138 | + |
| 139 | + schema_branch = await registry.schema.load_schema_from_db(db=db, branch=branch) |
| 140 | + |
| 141 | + migrations = [] |
| 142 | + get_added_nodes_by_kind_for_branch_query = await GetAddedNodesByKindForBranchQuery.init(db=db, branch=branch) |
| 143 | + await get_added_nodes_by_kind_for_branch_query.execute(db=db) |
| 144 | + node_ids_by_kind = get_added_nodes_by_kind_for_branch_query.get_node_ids_by_kind() |
| 145 | + |
| 146 | + for node_kind, node_ids in node_ids_by_kind.items(): |
| 147 | + schema = schema_branch.get(name=node_kind, duplicate=False) |
| 148 | + migrations.extend( |
| 149 | + [ |
| 150 | + NodeAttributeAddMigration( |
| 151 | + uuids=node_ids, |
| 152 | + new_node_schema=schema, |
| 153 | + previous_node_schema=schema, |
| 154 | + schema_path=SchemaPath( |
| 155 | + schema_kind=schema.kind, path_type=SchemaPathType.ATTRIBUTE, field_name="human_friendly_id" |
| 156 | + ), |
| 157 | + ), |
| 158 | + NodeAttributeAddMigration( |
| 159 | + uuids=node_ids, |
| 160 | + new_node_schema=schema, |
| 161 | + previous_node_schema=schema, |
| 162 | + schema_path=SchemaPath( |
| 163 | + schema_kind=schema.kind, path_type=SchemaPathType.ATTRIBUTE, field_name="display_label" |
| 164 | + ), |
| 165 | + ), |
| 166 | + ] |
| 167 | + ) |
| 168 | + |
| 169 | + with Progress() as progress: |
| 170 | + update_task = progress.add_task( |
| 171 | + f"Adding HFID and display label to nodes on branch {branch.name}", total=len(migrations) |
| 172 | + ) |
| 173 | + |
| 174 | + for migration in migrations: |
| 175 | + try: |
| 176 | + execution_result = await migration.execute(db=db, branch=branch) |
| 177 | + result.errors.extend(execution_result.errors) |
| 178 | + progress.update(update_task, advance=1) |
| 179 | + except Exception as exc: |
| 180 | + result.errors.append(str(exc)) |
| 181 | + return result |
| 182 | + |
| 183 | + return result |
| 184 | + |
96 | 185 | async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002 |
97 | 186 | return MigrationResult() |
0 commit comments