Skip to content

Commit 6faa1ef

Browse files
authored
HFID/display label migration updates (#7514)
* WIP branch migration for HFID and display labels * start tests * more tests * handle schema update on branch * more tests * fixes and first test for migration on branch * test and fixes for update on branch * test for updating related nodes and their attributes * on-way rel tests and fixes * tests for branch-agnostic attr, rels, and attrs on peers * fix and tests for schema-level display_label/HFID change on branch * fix schema uniqueness
1 parent 4515bf6 commit 6faa1ef

File tree

6 files changed

+1309
-203
lines changed

6 files changed

+1309
-203
lines changed

backend/infrahub/core/migrations/graph/m042_create_hfid_display_label_in_db.py

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,47 @@
66
from typing_extensions import Self
77

88
from infrahub.core import registry
9+
from infrahub.core.branch import Branch
910
from infrahub.core.constants import SchemaPathType
10-
from infrahub.core.initialization import initialization
11+
from infrahub.core.initialization import get_root_node
1112
from infrahub.core.migrations.schema.node_attribute_add import NodeAttributeAddMigration
1213
from infrahub.core.migrations.shared import InternalSchemaMigration, MigrationResult
1314
from infrahub.core.path import SchemaPath
14-
from infrahub.lock import initialize_lock
15+
from infrahub.core.query import Query, QueryType
16+
from infrahub.core.schema import SchemaRoot, internal_schema
17+
from infrahub.core.schema.manager import SchemaManager
18+
from infrahub.exceptions import InitializationError
1519

1620
if TYPE_CHECKING:
21+
from infrahub.core.schema.schema_branch import SchemaBranch
1722
from infrahub.database import InfrahubDatabase
1823

1924

25+
class GetAddedNodesByKindForBranchQuery(Query):
26+
name = "get_added_nodes_by_kind_for_branch_query"
27+
type = QueryType.READ
28+
insert_return = True
29+
30+
async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None: # noqa: ARG002
31+
self.params["branch"] = self.branch.name
32+
query = """
33+
MATCH (n:Node)-[e:IS_PART_OF {branch: $branch, status: "active"}]->(:Root)
34+
WHERE e.to IS NULL
35+
AND NOT exists((n)-[:IS_PART_OF {branch: $branch, status: "deleted"}]->(:Root))
36+
WITH n.kind AS kind, collect(n.uuid) AS node_ids
37+
"""
38+
self.return_labels = ["kind", "node_ids"]
39+
self.add_to_query(query)
40+
41+
def get_node_ids_by_kind(self) -> dict[str, list[str]]:
42+
node_ids_by_kind: dict[str, list[str]] = {}
43+
for result in self.get_results():
44+
kind = result.get_as_type(label="kind", return_type=str)
45+
node_ids: list[str] = result.get_as_type(label="node_ids", return_type=list)
46+
node_ids_by_kind[kind] = node_ids
47+
return node_ids_by_kind
48+
49+
2050
class Migration042(InternalSchemaMigration):
2151
name: str = "042_create_hfid_display_label_in_db"
2252
minimum_version: int = 41
@@ -46,15 +76,25 @@ def init(cls, **kwargs: Any) -> Self:
4676
]
4777
return cls(migrations=cls.migrations, **kwargs) # type: ignore[arg-type]
4878

79+
async def _get_or_load_schema_branch(self, db: InfrahubDatabase, branch: Branch) -> SchemaBranch:
80+
try:
81+
if registry.schema.has_schema_branch(branch.name):
82+
return registry.schema.get_schema_branch(branch.name)
83+
except InitializationError:
84+
pass
85+
schema_manager = SchemaManager()
86+
internal_schema_root = SchemaRoot(**internal_schema)
87+
schema_manager.register_schema(schema=internal_schema_root)
88+
registry.schema = schema_manager
89+
return await schema_manager.load_schema_from_db(db=db, branch=branch)
90+
4991
async def execute(self, db: InfrahubDatabase) -> MigrationResult:
5092
result = MigrationResult()
5193

52-
# load schemas from database into registry
53-
initialize_lock()
54-
await initialization(db=db)
55-
56-
default_branch = registry.get_branch_from_registry()
57-
schema_branch = await registry.schema.load_schema_from_db(db=db, branch=default_branch)
94+
root_node = await get_root_node(db=db, initialize=False)
95+
default_branch_name = root_node.default_branch
96+
default_branch = await Branch.get_by_name(db=db, name=default_branch_name)
97+
schema_branch = await self._get_or_load_schema_branch(db=db, branch=default_branch)
5898

5999
migrations = list(self.migrations)
60100

@@ -93,5 +133,54 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult:
93133

94134
return result
95135

136+
async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
137+
result = MigrationResult()
138+
139+
schema_branch = await registry.schema.load_schema_from_db(db=db, branch=branch)
140+
141+
migrations = []
142+
get_added_nodes_by_kind_for_branch_query = await GetAddedNodesByKindForBranchQuery.init(db=db, branch=branch)
143+
await get_added_nodes_by_kind_for_branch_query.execute(db=db)
144+
node_ids_by_kind = get_added_nodes_by_kind_for_branch_query.get_node_ids_by_kind()
145+
146+
for node_kind, node_ids in node_ids_by_kind.items():
147+
schema = schema_branch.get(name=node_kind, duplicate=False)
148+
migrations.extend(
149+
[
150+
NodeAttributeAddMigration(
151+
uuids=node_ids,
152+
new_node_schema=schema,
153+
previous_node_schema=schema,
154+
schema_path=SchemaPath(
155+
schema_kind=schema.kind, path_type=SchemaPathType.ATTRIBUTE, field_name="human_friendly_id"
156+
),
157+
),
158+
NodeAttributeAddMigration(
159+
uuids=node_ids,
160+
new_node_schema=schema,
161+
previous_node_schema=schema,
162+
schema_path=SchemaPath(
163+
schema_kind=schema.kind, path_type=SchemaPathType.ATTRIBUTE, field_name="display_label"
164+
),
165+
),
166+
]
167+
)
168+
169+
with Progress() as progress:
170+
update_task = progress.add_task(
171+
f"Adding HFID and display label to nodes on branch {branch.name}", total=len(migrations)
172+
)
173+
174+
for migration in migrations:
175+
try:
176+
execution_result = await migration.execute(db=db, branch=branch)
177+
result.errors.extend(execution_result.errors)
178+
progress.update(update_task, advance=1)
179+
except Exception as exc:
180+
result.errors.append(str(exc))
181+
return result
182+
183+
return result
184+
96185
async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002
97186
return MigrationResult()

0 commit comments

Comments
 (0)