11from __future__ import annotations
22
3- from collections import defaultdict
43from typing import TYPE_CHECKING , Any
54
65from rich .console import Console
76from rich .progress import Progress
87
98from infrahub .core .branch .models import Branch
10- from infrahub .core .initialization import initialization
9+ from infrahub .core .initialization import get_root_node
1110from infrahub .core .manager import NodeManager
1211from infrahub .core .migrations .shared import MigrationResult
1312from infrahub .core .query import Query , QueryType
1413from infrahub .core .timestamp import Timestamp
15- from infrahub .lock import initialize_lock
1614from infrahub .log import get_logger
1715from infrahub .profiles .node_applier import NodeProfilesApplier
1816
19- from ..shared import ArbitraryMigration
17+ from ..shared import MigrationRequiringRebase
18+ from .load_schema_branch import get_or_load_schema_branch
2019
2120if TYPE_CHECKING :
22- from infrahub .core .node import Node
2321 from infrahub .database import InfrahubDatabase
2422
2523log = get_logger ()
2624
2725
28- class GetProfilesByBranchQuery (Query ):
26+ class GetUpdatedProfilesForBranchQuery (Query ):
2927 """
30- Get CoreProfile UUIDs by which branches they have attribute updates on
28+ Get CoreProfile UUIDs with updated attributes on this branch
3129 """
3230
3331 name = "get_profiles_by_branch"
3432 type = QueryType .READ
35- insert_return = False
3633
3734 async def query_init (self , db : InfrahubDatabase , ** kwargs : dict [str , Any ]) -> None : # noqa: ARG002
35+ self .params ["branch" ] = self .branch .name
3836 query = """
3937MATCH (profile:CoreProfile)-[:HAS_ATTRIBUTE]->(attr:Attribute)-[e:HAS_VALUE]->(:AttributeValue)
40- WITH DISTINCT profile.uuid AS profile_uuid, e.branch AS branch
41- RETURN profile_uuid, collect(branch) AS branches
38+ WHERE e.branch = $branch
4239 """
4340 self .add_to_query (query )
44- self .return_labels = ["profile_uuid" , "branches " ]
41+ self .return_labels = ["profile.uuid AS profile_uuid " ]
4542
46- def get_profile_ids_by_branch (self ) -> dict [str , set [str ]]:
47- """Get dictionary of branch names to set of updated profile UUIDs"""
48- profiles_by_branch = defaultdict (set )
49- for result in self .get_results ():
50- profile_uuid = result .get_as_type ("profile_uuid" , str )
51- branches = result .get_as_type ("branches" , list [str ])
52- for branch in branches :
53- profiles_by_branch [branch ].add (profile_uuid )
54- return profiles_by_branch
43+ def get_profile_ids (self ) -> list [str ]:
44+ """Get list of updated profile UUIDs"""
45+ return [result .get_as_type ("profile_uuid" , str ) for result in self .get_results ()]
5546
5647
57- class GetNodesWithProfileUpdatesByBranchQuery (Query ):
48+ class GetNodesWithProfileUpdatesForBranchQuery (Query ):
5849 """
5950 Get Node UUIDs by which branches they have updated profiles on
6051 """
6152
6253 name = "get_nodes_with_profile_updates_by_branch"
6354 type = QueryType .READ
64- insert_return = False
6555
6656 async def query_init (self , db : InfrahubDatabase , ** kwargs : dict [str , Any ]) -> None : # noqa: ARG002
57+ self .params ["branch" ] = self .branch .name
6758 query = """
68- MATCH (node:Node)-[e1 :IS_RELATED]->(:Relationship {name: "node__profile"})
59+ MATCH (node:Node)-[e :IS_RELATED]->(:Relationship {name: "node__profile"})
6960WHERE NOT node:CoreProfile
70- WITH DISTINCT node.uuid AS node_uuid, e1. branch AS branch
71- RETURN node_uuid, collect(branch) AS branches
61+ AND e. branch = $ branch
62+ WITH DISTINCT node.uuid AS node_uuid
7263 """
7364 self .add_to_query (query )
74- self .return_labels = ["node_uuid" , "branches" ]
65+ self .return_labels = ["node_uuid" ]
7566
76- def get_node_ids_by_branch (self ) -> dict [str , set [str ]]:
77- """Get dictionary of branch names to set of updated node UUIDs"""
78- nodes_by_branch = defaultdict (set )
79- for result in self .get_results ():
80- node_uuid = result .get_as_type ("node_uuid" , str )
81- branches = result .get_as_type ("branches" , list [str ])
82- for branch in branches :
83- nodes_by_branch [branch ].add (node_uuid )
84- return nodes_by_branch
67+ def get_node_ids (self ) -> list [str ]:
68+ """Get list of updated node UUIDs"""
69+ return [result .get_as_type ("node_uuid" , str ) for result in self .get_results ()]
8570
8671
87- class Migration041 (ArbitraryMigration ):
72+ class Migration041 (MigrationRequiringRebase ):
8873 """
8974 Save profile attribute values on each node using the profile in the database
9075 For any profile that has updates on a given branch (including default branch)
@@ -96,71 +81,64 @@ class Migration041(ArbitraryMigration):
9681 name : str = "041_profile_attrs_in_db"
9782 minimum_version : int = 40
9883
99- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
100- super ().__init__ (* args , ** kwargs )
101- self ._appliers_by_branch : dict [str , NodeProfilesApplier ] = {}
102-
103- async def _get_profile_applier (self , db : InfrahubDatabase , branch_name : str ) -> NodeProfilesApplier :
104- if branch_name not in self ._appliers_by_branch :
105- branch = await Branch .get_by_name (db = db , name = branch_name )
106- self ._appliers_by_branch [branch_name ] = NodeProfilesApplier (db = db , branch = branch )
107- return self ._appliers_by_branch [branch_name ]
84+ def _get_profile_applier (self , db : InfrahubDatabase , branch : Branch ) -> NodeProfilesApplier :
85+ return NodeProfilesApplier (db = db , branch = branch )
10886
10987 async def validate_migration (self , db : InfrahubDatabase ) -> MigrationResult : # noqa: ARG002
11088 return MigrationResult ()
11189
11290 async def execute (self , db : InfrahubDatabase ) -> MigrationResult :
91+ root_node = await get_root_node (db = db , initialize = False )
92+ default_branch_name = root_node .default_branch
93+ default_branch = await Branch .get_by_name (db = db , name = default_branch_name )
94+ return await self ._do_execute_for_branch (db = db , branch = default_branch )
95+
96+ async def execute_against_branch (self , db : InfrahubDatabase , branch : Branch ) -> MigrationResult :
97+ return await self ._do_execute_for_branch (db = db , branch = branch )
98+
99+ async def _do_execute_for_branch (self , db : InfrahubDatabase , branch : Branch ) -> MigrationResult :
113100 console = Console ()
114101 result = MigrationResult ()
115- # load schemas from database into registry
116- initialize_lock ()
117- await initialization (db = db )
118-
119- console .print ("Gathering profiles for each branch..." , end = "" )
120- get_profiles_by_branch_query = await GetProfilesByBranchQuery .init (db = db )
121- await get_profiles_by_branch_query .execute (db = db )
122- profiles_ids_by_branch = get_profiles_by_branch_query .get_profile_ids_by_branch ()
123-
124- profiles_by_branch : dict [str , list [Node ]] = {}
125- for branch_name , profile_ids in profiles_ids_by_branch .items ():
126- profiles_map = await NodeManager .get_many (db = db , branch = branch_name , ids = list (profile_ids ))
127- profiles_by_branch [branch_name ] = list (profiles_map .values ())
102+ await get_or_load_schema_branch (db = db , branch = branch )
103+
104+ console .print (f"Gathering profiles for each branch { branch .name } ..." , end = "" )
105+ get_updated_profiles_for_branch_query = await GetUpdatedProfilesForBranchQuery .init (db = db , branch = branch )
106+ await get_updated_profiles_for_branch_query .execute (db = db )
107+ profile_ids = get_updated_profiles_for_branch_query .get_profile_ids ()
108+
109+ profiles_map = await NodeManager .get_many (db = db , branch = branch , ids = list (profile_ids ))
128110 console .print ("done" )
129111
130- node_ids_to_update_by_branch : dict [str , set [str ]] = defaultdict (set )
131- total_size = sum (len (profiles ) for profiles in profiles_by_branch .values ())
112+ node_ids_to_update : set [str ] = set ()
132113 with Progress () as progress :
133114 gather_nodes_task = progress .add_task (
134- "Gathering affected objects for each profile on each branch... " , total = total_size
115+ "Gathering affected objects for each profile on branch { branch.name}... " , total = len ( profiles_map )
135116 )
136117
137- for branch_name , profiles in profiles_by_branch .items ():
138- for profile in profiles :
139- node_relationship_manager = profile .get_relationship ("related_nodes" )
140- node_peers = await node_relationship_manager .get_db_peers (db = db )
141- node_ids_to_update_by_branch [branch_name ].update ({str (peer .peer_id ) for peer in node_peers })
142- progress .update (gather_nodes_task , advance = 1 )
118+ for profile in profiles_map .values ():
119+ node_relationship_manager = profile .get_relationship ("related_nodes" )
120+ node_peers = await node_relationship_manager .get_db_peers (db = db )
121+ node_ids_to_update .update (str (peer .peer_id ) for peer in node_peers )
122+ progress .update (gather_nodes_task , advance = 1 )
143123
144124 console .print ("Identifying nodes with profile updates by branch..." , end = "" )
145- get_nodes_with_profile_updates_by_branch_query = await GetNodesWithProfileUpdatesByBranchQuery .init (db = db )
125+ get_nodes_with_profile_updates_by_branch_query = await GetNodesWithProfileUpdatesForBranchQuery .init (
126+ db = db , branch = branch
127+ )
146128 await get_nodes_with_profile_updates_by_branch_query .execute (db = db )
147- nodes_ids_by_branch = get_nodes_with_profile_updates_by_branch_query .get_node_ids_by_branch ()
148- for branch_name , node_ids in nodes_ids_by_branch .items ():
149- node_ids_to_update_by_branch [branch_name ].update (node_ids )
129+ node_ids_to_update .update (get_nodes_with_profile_updates_by_branch_query .get_node_ids ())
150130 console .print ("done" )
151131
152132 right_now = Timestamp ()
153- total_size = sum (len (node_ids ) for node_ids in node_ids_to_update_by_branch .values ())
154133 with Progress () as progress :
155- apply_task = progress .add_task ("Applying profiles to nodes..." , total = total_size )
156- for branch_name , node_ids in node_ids_to_update_by_branch .items ():
157- applier = await self ._get_profile_applier (db = db , branch_name = branch_name )
158- for node_id in node_ids :
159- node = await NodeManager .get_one (db = db , branch = branch_name , id = node_id , at = right_now )
160- if node :
161- updated_field_names = await applier .apply_profiles (node = node )
162- if updated_field_names :
163- await node .save (db = db , fields = updated_field_names , at = right_now )
164- progress .update (apply_task , advance = 1 )
134+ apply_task = progress .add_task ("Applying profiles to nodes..." , total = len (node_ids_to_update ))
135+ applier = self ._get_profile_applier (db = db , branch = branch )
136+ for node_id in node_ids_to_update :
137+ node = await NodeManager .get_one (db = db , branch = branch , id = node_id , at = right_now )
138+ if node :
139+ updated_field_names = await applier .apply_profiles (node = node )
140+ if updated_field_names :
141+ await node .save (db = db , fields = updated_field_names , at = right_now )
142+ progress .update (apply_task , advance = 1 )
165143
166144 return result
0 commit comments