11from __future__ import annotations
22
3+ from enum import Enum
34from typing import TYPE_CHECKING , Self
45
56from graphene import Boolean , InputField , InputObjectType , List , Mutation , String
910from infrahub .core .changelog .models import NodeChangelog
1011from infrahub .core .constants import InfrahubKind , MutationAction , RelationshipCardinality
1112from infrahub .core .manager import NodeManager
13+ from infrahub .core .query .node import NodeGetKindQuery
1214from infrahub .core .query .relationship import (
1315 RelationshipGetPeerQuery ,
1416 RelationshipPeerData ,
3335RELATIONSHIP_PEERS_TO_IGNORE = [InfrahubKind .NODE ]
3436
3537
38+ class GroupUpdateType (str , Enum ):
39+ NONE = "none"
40+ MEMBERS = "members"
41+ MEMBER_OF_GROUPS = "member_of_groups"
42+
43+
3644class RelationshipNodesInput (InputObjectType ):
3745 id = InputField (String (required = True ), description = "ID of the node at the source of the relationship" )
3846 name = InputField (String (required = True ), description = "Name of the relationship to add or remove nodes" )
@@ -65,16 +73,24 @@ async def mutate( # noqa: PLR0915
6573 raise NodeNotFoundError (node_type = "node" , identifier = input_id , branch_name = graphql_context .branch .name )
6674
6775 # Check if the name of the relationship provided exist for this node and is of cardinality Many
68- if relationship_name not in source ._schema .relationship_names :
76+ if relationship_name not in source .get_schema () .relationship_names :
6977 raise ValidationError (
7078 {"name" : f"'{ relationship_name } ' is not a valid relationship for '{ source .get_kind ()} '" }
7179 )
7280
73- rel_schema = source ._schema .get_relationship (name = relationship_name )
81+ rel_schema = source .get_schema () .get_relationship (name = relationship_name )
7482 if rel_schema .cardinality != RelationshipCardinality .MANY :
7583 raise ValidationError ({"name" : f"'{ relationship_name } ' must be a relationship of cardinality Many" })
7684
77- is_group_event = rel_schema .identifier == "group_member" and "CoreGroup" in source ._schema .inherit_from
85+ group_event_type = GroupUpdateType .NONE
86+ if rel_schema .identifier == "group_member" :
87+ if "CoreGroup" in source .get_schema ().inherit_from and relationship_name == "members" :
88+ # Updating members of a group
89+ group_event_type = GroupUpdateType .MEMBERS
90+
91+ elif relationship_name == "member_of_groups" :
92+ # Modifying the membership of the current node
93+ group_event_type = GroupUpdateType .MEMBER_OF_GROUPS
7894
7995 # Query the node in the database and validate that all of them exist and are if the correct kind
8096 node_ids : list [str ] = [node_data ["id" ] for node_data in data .get ("nodes" ) if "id" in node_data ]
@@ -93,7 +109,9 @@ async def mutate( # noqa: PLR0915
93109 if rel_schema .peer not in node .get_labels ():
94110 raise ValidationError (f"{ node_id !r} { node .get_kind ()!r} is not a valid peer for '{ rel_schema .peer } '" )
95111
96- peer_relationships = [rel for rel in node ._schema .relationships if rel .identifier == rel_schema .identifier ]
112+ peer_relationships = [
113+ rel for rel in node .get_schema ().relationships if rel .identifier == rel_schema .identifier
114+ ]
97115 if (
98116 rel_schema .identifier
99117 and len (peer_relationships ) == 1
@@ -120,7 +138,7 @@ async def mutate( # noqa: PLR0915
120138 await query .execute (db = graphql_context .db )
121139 existing_peers : dict [str , RelationshipPeerData ] = {str (peer .peer_id ): peer for peer in query .get_peers ()}
122140 async with graphql_context .db .start_transaction () as db :
123- members = []
141+ peers : list [ EventNode ] = []
124142 if cls .__name__ == "RelationshipAdd" :
125143 for node_data in data .get ("nodes" ):
126144 # Instantiate and resolve a relationship
@@ -130,7 +148,7 @@ async def mutate( # noqa: PLR0915
130148 await rel .resolve (db = db )
131149 # Save it only if it does not exist
132150 if rel .get_peer_id () not in existing_peers .keys ():
133- members .append (EventNode (id = rel .get_peer_id (), kind = rel .get_peer_kind ()))
151+ peers .append (EventNode (id = rel .get_peer_id (), kind = rel .get_peer_kind ()))
134152 node_changelog .create_relationship (relationship = rel )
135153 await rel .save (db = db )
136154
@@ -142,13 +160,13 @@ async def mutate( # noqa: PLR0915
142160 # it would be more query efficient
143161 rel = Relationship (schema = rel_schema , branch = graphql_context .branch , node = source )
144162 await rel .load (db = db , data = existing_peers [node_data .get ("id" )])
145- members .append (EventNode (id = rel .get_peer_id (), kind = rel .get_peer_kind ()))
163+ peers .append (EventNode (id = rel .get_peer_id (), kind = rel .get_peer_kind ()))
146164 node_changelog .delete_relationship (relationship = rel )
147165 await rel .delete (db = db )
148166
149167 if config .SETTINGS .broker .enable and graphql_context .background :
150168 event = NodeMutatedEvent (
151- kind = source ._schema .kind ,
169+ kind = source .get_schema () .kind ,
152170 node_id = source .id ,
153171 data = node_changelog ,
154172 action = MutationAction .UPDATED ,
@@ -158,18 +176,43 @@ async def mutate( # noqa: PLR0915
158176 ),
159177 )
160178 graphql_context .background .add_task (graphql_context .active_service .event .send , event )
161- if is_group_event :
179+ if group_event_type == GroupUpdateType . MEMBERS :
162180 if cls .__name__ == "RelationshipAdd" :
163181 group_add_event = GroupMemberAddedEvent (
164- node_id = source .id , kind = source ._schema .kind , members = members
182+ node_id = source .id , kind = source .get_schema () .kind , members = peers
165183 )
166184 graphql_context .background .add_task (graphql_context .active_service .event .send , group_add_event )
167185 elif cls .__name__ == "RelationshipRemove" :
168186 group_remove_event = GroupMemberRemovedEvent (
169- node_id = source .id , kind = source ._schema .kind , members = members
187+ node_id = source .id , kind = source .get_schema () .kind , members = peers
170188 )
171189 graphql_context .background .add_task (graphql_context .active_service .event .send , group_remove_event )
172-
190+ elif group_event_type == GroupUpdateType .MEMBER_OF_GROUPS :
191+ group_ids = [node .id for node in peers ]
192+ async with graphql_context .db .start_session () as db :
193+ node_kind_query = await NodeGetKindQuery .init (db = db , branch = graphql_context .branch , ids = group_ids )
194+ await node_kind_query .execute (db = db )
195+ node_kind_map = await node_kind_query .get_node_kind_map ()
196+
197+ for node_id , node_kind in node_kind_map .items ():
198+ if cls .__name__ == "RelationshipAdd" :
199+ group_add_event = GroupMemberAddedEvent (
200+ node_id = node_id ,
201+ kind = node_kind ,
202+ members = [EventNode (id = source .get_id (), kind = source .get_kind ())],
203+ )
204+ graphql_context .background .add_task (
205+ graphql_context .active_service .event .send , group_add_event
206+ )
207+ elif cls .__name__ == "RelationshipRemove" :
208+ group_remove_event = GroupMemberRemovedEvent (
209+ node_id = node_id ,
210+ kind = node_kind ,
211+ members = [EventNode (id = source .get_id (), kind = source .get_kind ())],
212+ )
213+ graphql_context .background .add_task (
214+ graphql_context .active_service .event .send , group_remove_event
215+ )
173216 return cls (ok = True ) # type: ignore[call-arg]
174217
175218
0 commit comments