Skip to content

Commit 4f77799

Browse files
authored
update profiles migration for rebase and branch-specific execution (#7526)
* move get_or_load_schema_branch to its own module * add query to clear property on a branch that exists on default branch * update Migration041 to support rebase and execute on branches * fix get_or_load_schema_branch * make migration041 a MigrationRequiringRebase * update test_runner unit test
1 parent 821bf43 commit 4f77799

File tree

8 files changed

+200
-160
lines changed

8 files changed

+200
-160
lines changed

backend/infrahub/core/attribute.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from infrahub.core.constants import NULL_VALUE, AttributeDBNodeType, BranchSupportType, RelationshipStatus
1919
from infrahub.core.property import FlagPropertyMixin, NodePropertyData, NodePropertyMixin
2020
from infrahub.core.query.attribute import (
21+
AttributeClearNodePropertyQuery,
2122
AttributeGetQuery,
2223
AttributeUpdateFlagQuery,
2324
AttributeUpdateNodePropertyQuery,
@@ -491,6 +492,12 @@ async def _update(self, db: InfrahubDatabase, at: Timestamp | None = None) -> At
491492
)
492493
await query.execute(db=db)
493494

495+
if needs_clear:
496+
query = await AttributeClearNodePropertyQuery.init(
497+
db=db, attr=self, at=update_at, prop_name=prop_name, prop_id=database_prop_id
498+
)
499+
await query.execute(db=db)
500+
494501
# set the to time on the previously active edge
495502
rel = current_attr_result.get(f"rel_{prop_name}")
496503
if rel and rel.get("branch") == branch.name:
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from infrahub.core import registry
2+
from infrahub.core.branch import Branch
3+
from infrahub.core.schema import SchemaRoot, internal_schema
4+
from infrahub.core.schema.manager import SchemaManager
5+
from infrahub.core.schema.schema_branch import SchemaBranch
6+
from infrahub.database import InfrahubDatabase
7+
from infrahub.exceptions import InitializationError
8+
9+
10+
async def get_or_load_schema_branch(db: InfrahubDatabase, branch: Branch) -> SchemaBranch:
    """Return the schema branch for ``branch``, loading it from the database if needed.

    If the registry already holds a schema branch under ``branch.name`` it is
    returned directly. If accessing the registry schema raises
    ``InitializationError``, a new ``SchemaManager`` is installed on the
    registry and the internal schema is registered on it. In every case where
    no schema branch was returned from the registry, the schema is loaded from
    the database for ``branch``.
    """
    branch_name = branch.name
    try:
        already_registered = registry.schema.has_schema_branch(branch_name)
        if already_registered:
            return registry.schema.get_schema_branch(branch_name)
    except InitializationError:
        # Install a new manager on the registry first, then seed it with the
        # internal schema before loading anything from the database.
        registry.schema = SchemaManager()
        root_of_internal_schema = SchemaRoot(**internal_schema)
        registry.schema.register_schema(schema=root_of_internal_schema)

    loaded_schema_branch = await registry.schema.load_schema_from_db(db=db, branch=branch)
    return loaded_schema_branch
Lines changed: 60 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,90 +1,75 @@
11
from __future__ import annotations
22

3-
from collections import defaultdict
43
from typing import TYPE_CHECKING, Any
54

65
from rich.console import Console
76
from rich.progress import Progress
87

98
from infrahub.core.branch.models import Branch
10-
from infrahub.core.initialization import initialization
9+
from infrahub.core.initialization import get_root_node
1110
from infrahub.core.manager import NodeManager
1211
from infrahub.core.migrations.shared import MigrationResult
1312
from infrahub.core.query import Query, QueryType
1413
from infrahub.core.timestamp import Timestamp
15-
from infrahub.lock import initialize_lock
1614
from infrahub.log import get_logger
1715
from infrahub.profiles.node_applier import NodeProfilesApplier
1816

19-
from ..shared import ArbitraryMigration
17+
from ..shared import MigrationRequiringRebase
18+
from .load_schema_branch import get_or_load_schema_branch
2019

2120
if TYPE_CHECKING:
22-
from infrahub.core.node import Node
2321
from infrahub.database import InfrahubDatabase
2422

2523
log = get_logger()
2624

2725

28-
class GetProfilesByBranchQuery(Query):
26+
class GetUpdatedProfilesForBranchQuery(Query):
    """
    Get CoreProfile UUIDs with updated attributes on this branch

    A profile matches when at least one of its attributes has a HAS_VALUE edge
    whose ``branch`` property equals the name of the branch the query runs on.
    """

    name = "get_profiles_by_branch"
    type = QueryType.READ

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        """Build the Cypher query and bind the ``$branch`` parameter."""
        self.params["branch"] = self.branch.name
        # DISTINCT collapses the one-row-per-attribute-edge result down to one
        # row per profile, the same way the node query in this module does.
        query = """
        MATCH (profile:CoreProfile)-[:HAS_ATTRIBUTE]->(attr:Attribute)-[e:HAS_VALUE]->(:AttributeValue)
        WHERE e.branch = $branch
        WITH DISTINCT profile.uuid AS profile_uuid
        """
        self.add_to_query(query)
        self.return_labels = ["profile_uuid"]

    def get_profile_ids(self) -> list[str]:
        """Get list of updated profile UUIDs, one entry per profile"""
        return [result.get_as_type("profile_uuid", str) for result in self.get_results()]
5546

5647

57-
class GetNodesWithProfileUpdatesByBranchQuery(Query):
48+
class GetNodesWithProfileUpdatesForBranchQuery(Query):
    """
    Get UUIDs of non-profile nodes that have a "node__profile" relationship edge on this branch
    """

    name = "get_nodes_with_profile_updates_by_branch"
    type = QueryType.READ

    async def query_init(self, db: InfrahubDatabase, **kwargs: dict[str, Any]) -> None:  # noqa: ARG002
        """Build the Cypher query and bind the ``$branch`` parameter."""
        self.params["branch"] = self.branch.name
        query = """
        MATCH (node:Node)-[e:IS_RELATED]->(:Relationship {name: "node__profile"})
        WHERE NOT node:CoreProfile
        AND e.branch = $branch
        WITH DISTINCT node.uuid AS node_uuid
        """
        self.add_to_query(query)
        self.return_labels = ["node_uuid"]

    def get_node_ids(self) -> list[str]:
        """Get list of updated node UUIDs"""
        node_uuids: list[str] = []
        for row in self.get_results():
            node_uuids.append(row.get_as_type("node_uuid", str))
        return node_uuids
8570

8671

87-
class Migration041(ArbitraryMigration):
72+
class Migration041(MigrationRequiringRebase):
8873
"""
8974
Save profile attribute values on each node using the profile in the database
9075
For any profile that has updates on a given branch (including default branch)
@@ -96,71 +81,64 @@ class Migration041(ArbitraryMigration):
9681
name: str = "041_profile_attrs_in_db"
9782
minimum_version: int = 40
9883

99-
def __init__(self, *args: Any, **kwargs: Any) -> None:
100-
super().__init__(*args, **kwargs)
101-
self._appliers_by_branch: dict[str, NodeProfilesApplier] = {}
102-
103-
async def _get_profile_applier(self, db: InfrahubDatabase, branch_name: str) -> NodeProfilesApplier:
104-
if branch_name not in self._appliers_by_branch:
105-
branch = await Branch.get_by_name(db=db, name=branch_name)
106-
self._appliers_by_branch[branch_name] = NodeProfilesApplier(db=db, branch=branch)
107-
return self._appliers_by_branch[branch_name]
84+
def _get_profile_applier(self, db: InfrahubDatabase, branch: Branch) -> NodeProfilesApplier:
    """Build a ``NodeProfilesApplier`` bound to the given database and branch."""
    profile_applier = NodeProfilesApplier(db=db, branch=branch)
    return profile_applier
10886

10987
async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult:  # noqa: ARG002
    """Return a fresh ``MigrationResult`` without inspecting the database."""
    validation_result = MigrationResult()
    return validation_result
11189

11290
async def execute(self, db: InfrahubDatabase) -> MigrationResult:
    """Run the migration against the default branch.

    The default branch name is read from the root node (fetched with
    ``initialize=False``) and resolved to a ``Branch`` before delegating to
    ``_do_execute_for_branch``.
    """
    root = await get_root_node(db=db, initialize=False)
    target_branch = await Branch.get_by_name(db=db, name=root.default_branch)
    migration_result = await self._do_execute_for_branch(db=db, branch=target_branch)
    return migration_result
95+
96+
async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
    """Run the migration against the given branch by delegating to ``_do_execute_for_branch``."""
    branch_result = await self._do_execute_for_branch(db=db, branch=branch)
    return branch_result
98+
99+
async def _do_execute_for_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
    """Apply profiles to every affected node on a single branch.

    Steps:
      1. Make sure the schema for ``branch`` is available in the registry.
      2. Find the profiles that have attribute updates on ``branch`` and load them.
      3. Collect the IDs of the nodes related to those profiles, plus the IDs of
         nodes whose profile relationship has an edge on ``branch``.
      4. Re-apply profiles to each collected node and save the fields that changed.

    Args:
        db: Database handle used for every query and save.
        branch: Branch the migration is executed against.

    Returns:
        The ``MigrationResult`` created at the start of the run; this method
        does not record anything on it.
    """
    console = Console()
    result = MigrationResult()
    await get_or_load_schema_branch(db=db, branch=branch)

    console.print(f"Gathering profiles for branch {branch.name}...", end="")
    get_updated_profiles_for_branch_query = await GetUpdatedProfilesForBranchQuery.init(db=db, branch=branch)
    await get_updated_profiles_for_branch_query.execute(db=db)
    profile_ids = get_updated_profiles_for_branch_query.get_profile_ids()

    # get_profile_ids() already returns a list, so no extra copy is needed
    profiles_map = await NodeManager.get_many(db=db, branch=branch, ids=profile_ids)
    console.print("done")

    node_ids_to_update: set[str] = set()
    with Progress() as progress:
        # f-string is required here, otherwise "{branch.name}" is displayed literally
        gather_nodes_task = progress.add_task(
            f"Gathering affected objects for each profile on branch {branch.name}...", total=len(profiles_map)
        )

        for profile in profiles_map.values():
            node_relationship_manager = profile.get_relationship("related_nodes")
            node_peers = await node_relationship_manager.get_db_peers(db=db)
            node_ids_to_update.update(str(peer.peer_id) for peer in node_peers)
            progress.update(gather_nodes_task, advance=1)

    console.print("Identifying nodes with profile updates by branch...", end="")
    get_nodes_with_profile_updates_by_branch_query = await GetNodesWithProfileUpdatesForBranchQuery.init(
        db=db, branch=branch
    )
    await get_nodes_with_profile_updates_by_branch_query.execute(db=db)
    node_ids_to_update.update(get_nodes_with_profile_updates_by_branch_query.get_node_ids())
    console.print("done")

    # a single timestamp is shared by every read and save below
    right_now = Timestamp()
    with Progress() as progress:
        apply_task = progress.add_task("Applying profiles to nodes...", total=len(node_ids_to_update))
        applier = self._get_profile_applier(db=db, branch=branch)
        for node_id in node_ids_to_update:
            node = await NodeManager.get_one(db=db, branch=branch, id=node_id, at=right_now)
            if node:
                updated_field_names = await applier.apply_profiles(node=node)
                # only write when applying the profiles actually changed something
                if updated_field_names:
                    await node.save(db=db, fields=updated_field_names, at=right_now)
            progress.update(apply_task, advance=1)

    return result

backend/infrahub/core/migrations/graph/m042_create_hfid_display_label_in_db.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,10 @@
1313
from infrahub.core.migrations.shared import InternalSchemaMigration, MigrationResult
1414
from infrahub.core.path import SchemaPath
1515
from infrahub.core.query import Query, QueryType
16-
from infrahub.core.schema import SchemaRoot, internal_schema
17-
from infrahub.core.schema.manager import SchemaManager
18-
from infrahub.exceptions import InitializationError
16+
17+
from .load_schema_branch import get_or_load_schema_branch
1918

2019
if TYPE_CHECKING:
21-
from infrahub.core.schema.schema_branch import SchemaBranch
2220
from infrahub.database import InfrahubDatabase
2321

2422

@@ -76,25 +74,13 @@ def init(cls, **kwargs: Any) -> Self:
7674
]
7775
return cls(migrations=cls.migrations, **kwargs) # type: ignore[arg-type]
7876

79-
async def _get_or_load_schema_branch(self, db: InfrahubDatabase, branch: Branch) -> SchemaBranch:
80-
try:
81-
if registry.schema.has_schema_branch(branch.name):
82-
return registry.schema.get_schema_branch(branch.name)
83-
except InitializationError:
84-
pass
85-
schema_manager = SchemaManager()
86-
internal_schema_root = SchemaRoot(**internal_schema)
87-
schema_manager.register_schema(schema=internal_schema_root)
88-
registry.schema = schema_manager
89-
return await schema_manager.load_schema_from_db(db=db, branch=branch)
90-
9177
async def execute(self, db: InfrahubDatabase) -> MigrationResult:
9278
result = MigrationResult()
9379

9480
root_node = await get_root_node(db=db, initialize=False)
9581
default_branch_name = root_node.default_branch
9682
default_branch = await Branch.get_by_name(db=db, name=default_branch_name)
97-
schema_branch = await self._get_or_load_schema_branch(db=db, branch=default_branch)
83+
schema_branch = await get_or_load_schema_branch(db=db, branch=default_branch)
9884

9985
migrations = list(self.migrations)
10086

backend/infrahub/core/migrations/graph/m043_backfill_hfid_display_label_in_db.py

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,13 @@
1313
from infrahub.core.initialization import get_root_node
1414
from infrahub.core.migrations.shared import MigrationResult
1515
from infrahub.core.query import Query, QueryType
16-
from infrahub.core.schema import AttributeSchema, NodeSchema, SchemaRoot, internal_schema
17-
from infrahub.core.schema.manager import SchemaManager
18-
from infrahub.exceptions import InitializationError
1916
from infrahub.types import is_large_attribute_type
2017

2118
from ..shared import MigrationRequiringRebase
19+
from .load_schema_branch import get_or_load_schema_branch
2220

2321
if TYPE_CHECKING:
22+
from infrahub.core.schema import AttributeSchema, NodeSchema
2423
from infrahub.core.schema.basenode_schema import SchemaAttributePath
2524
from infrahub.core.schema.schema_branch import SchemaBranch
2625
from infrahub.database import InfrahubDatabase
@@ -620,18 +619,6 @@ class Migration043(MigrationRequiringRebase):
620619
async def validate_migration(self, db: InfrahubDatabase) -> MigrationResult: # noqa: ARG002
621620
return MigrationResult()
622621

623-
async def _get_or_load_schema_branch(self, db: InfrahubDatabase, branch: Branch) -> SchemaBranch:
624-
try:
625-
if registry.schema.has_schema_branch(branch.name):
626-
return registry.schema.get_schema_branch(branch.name)
627-
except InitializationError:
628-
pass
629-
schema_manager = SchemaManager()
630-
internal_schema_root = SchemaRoot(**internal_schema)
631-
schema_manager.register_schema(schema=internal_schema_root)
632-
registry.schema = schema_manager
633-
return await schema_manager.load_schema_from_db(db=db, branch=branch)
634-
635622
async def _do_one_schema_all(
636623
self,
637624
db: InfrahubDatabase,
@@ -722,7 +709,7 @@ async def execute(self, db: InfrahubDatabase) -> MigrationResult:
722709
default_branch_name = root_node.default_branch
723710
default_branch = await Branch.get_by_name(db=db, name=default_branch_name)
724711

725-
main_schema_branch = await self._get_or_load_schema_branch(db=db, branch=default_branch)
712+
main_schema_branch = await get_or_load_schema_branch(db=db, branch=default_branch)
726713

727714
total_nodes_query = await DefaultBranchNodeCount.init(db=db, kinds_to_skip=self.kinds_to_skip)
728715
await total_nodes_query.execute(db=db)
@@ -827,8 +814,8 @@ async def _do_one_schema_branch(
827814

828815
async def execute_against_branch(self, db: InfrahubDatabase, branch: Branch) -> MigrationResult:
829816
default_branch = await Branch.get_by_name(db=db, name=registry.default_branch)
830-
main_schema_branch = await self._get_or_load_schema_branch(db=db, branch=default_branch)
831-
schema_branch = await self._get_or_load_schema_branch(db=db, branch=branch)
817+
main_schema_branch = await get_or_load_schema_branch(db=db, branch=default_branch)
818+
schema_branch = await get_or_load_schema_branch(db=db, branch=branch)
832819

833820
base_node_schema = schema_branch.get("SchemaNode", duplicate=False)
834821
display_label_attribute_schema = base_node_schema.get_attribute("display_label")

0 commit comments

Comments
 (0)