11from __future__ import annotations
22
3- from typing import TYPE_CHECKING , Any
3+ from typing import TYPE_CHECKING , Any , cast
44
55from graphene import InputObjectType , Mutation
66from typing_extensions import Self
77
8- from infrahub .core .schema import NodeSchema
8+ from infrahub .core .protocols import CoreNodeTriggerAttributeMatch , CoreNodeTriggerRelationshipMatch , CoreNodeTriggerRule
9+ from infrahub .exceptions import SchemaNotFoundError , ValidationError
910from infrahub .log import get_logger
1011
1112from .main import InfrahubMutationMixin , InfrahubMutationOptions
1516
1617 from infrahub .core .branch import Branch
1718 from infrahub .core .node import Node
19+ from infrahub .core .schema import NodeSchema
1820 from infrahub .database import InfrahubDatabase
1921
22+ from ..initialization import GraphqlContext
23+
2024log = get_logger ()
2125
2226
@@ -28,10 +32,6 @@ def __init_subclass_with_meta__(
2832 _meta : Any | None = None ,
2933 ** options : dict [str , Any ],
3034 ) -> None :
31- # Make sure schema is a valid NodeSchema Node Class
32- if not isinstance (schema , NodeSchema ):
33- raise ValueError (f"You need to pass a valid NodeSchema in '{ cls .__name__ } .Meta', received '{ schema } '" )
34-
3535 if not _meta :
3636 _meta = InfrahubMutationOptions (cls )
3737
@@ -45,9 +45,12 @@ async def mutate_create(
4545 info : GraphQLResolveInfo ,
4646 data : InputObjectType ,
4747 branch : Branch ,
48- database : InfrahubDatabase | None = None , # noqa: ARG003
48+ database : InfrahubDatabase | None = None ,
4949 ) -> tuple [Node , Self ]:
50- trigger_rule_definition , result = await super ().mutate_create (info = info , data = data , branch = branch )
50+ graphql_context : GraphqlContext = info .context
51+ db = database or graphql_context .db
52+ _validate_node_kind (data = data , db = db )
53+ trigger_rule_definition , result = await super ().mutate_create (info = info , data = data , branch = branch , database = db )
5154
5255 return trigger_rule_definition , result
5356
@@ -57,9 +60,105 @@ async def mutate_update(
5760 info : GraphQLResolveInfo ,
5861 data : InputObjectType ,
5962 branch : Branch ,
60- database : InfrahubDatabase | None = None , # noqa: ARG003
63+ database : InfrahubDatabase | None = None ,
6164 node : Node | None = None , # noqa: ARG003
6265 ) -> tuple [Node , Self ]:
63- trigger_rule_definition , result = await super ().mutate_update (info = info , data = data , branch = branch )
66+ graphql_context : GraphqlContext = info .context
67+ db = database or graphql_context .db
68+ _validate_node_kind (data = data , db = db )
69+ trigger_rule_definition , result = await super ().mutate_update (info = info , data = data , branch = branch , database = db )
6470
6571 return trigger_rule_definition , result
72+
73+
74+ class InfrahubTriggerRuleMatchMutation (InfrahubMutationMixin , Mutation ):
75+ @classmethod
76+ def __init_subclass_with_meta__ (
77+ cls ,
78+ schema : NodeSchema ,
79+ _meta : Any | None = None ,
80+ ** options : dict [str , Any ],
81+ ) -> None :
82+ if not _meta :
83+ _meta = InfrahubMutationOptions (cls )
84+
85+ _meta .schema = schema
86+
87+ super ().__init_subclass_with_meta__ (_meta = _meta , ** options )
88+
89+ @classmethod
90+ async def mutate_create (
91+ cls ,
92+ info : GraphQLResolveInfo ,
93+ data : InputObjectType ,
94+ branch : Branch ,
95+ database : InfrahubDatabase | None = None , # noqa: ARG003
96+ ) -> tuple [Node , Self ]:
97+ graphql_context : GraphqlContext = info .context
98+
99+ async with graphql_context .db .start_transaction () as dbt :
100+ trigger_match , result = await super ().mutate_create (info = info , data = data , branch = branch , database = dbt )
101+ trigger_match_model = cast (CoreNodeTriggerAttributeMatch | CoreNodeTriggerRelationshipMatch , trigger_match )
102+ node_trigger_rule = await trigger_match_model .trigger .get_peer (db = dbt , raise_on_error = True )
103+ node_trigger_rule_model = cast (CoreNodeTriggerRule , node_trigger_rule )
104+ node_schema = dbt .schema .get_node_schema (name = node_trigger_rule_model .node_kind .value , duplicate = False )
105+ _validate_node_kind_field (data = data , node_schema = node_schema )
106+
107+ return trigger_match , result
108+
109+ @classmethod
110+ async def mutate_update (
111+ cls ,
112+ info : GraphQLResolveInfo ,
113+ data : InputObjectType ,
114+ branch : Branch ,
115+ database : InfrahubDatabase | None = None , # noqa: ARG003
116+ node : Node | None = None , # noqa: ARG003
117+ ) -> tuple [Node , Self ]:
118+ graphql_context : GraphqlContext = info .context
119+ async with graphql_context .db .start_transaction () as dbt :
120+ trigger_match , result = await super ().mutate_update (info = info , data = data , branch = branch , database = dbt )
121+ trigger_match_model = cast (CoreNodeTriggerAttributeMatch | CoreNodeTriggerRelationshipMatch , trigger_match )
122+ node_trigger_rule = await trigger_match_model .trigger .get_peer (db = dbt , raise_on_error = True )
123+ node_trigger_rule_model = cast (CoreNodeTriggerRule , node_trigger_rule )
124+ node_schema = dbt .schema .get_node_schema (name = node_trigger_rule_model .node_kind .value , duplicate = False )
125+ _validate_node_kind_field (data = data , node_schema = node_schema )
126+
127+ return trigger_match , result
128+
129+
130+ def _validate_node_kind (data : InputObjectType , db : InfrahubDatabase ) -> None :
131+ input_data = cast (dict [str , dict [str , Any ]], data )
132+ if node_kind := input_data .get ("node_kind" ):
133+ value = node_kind .get ("value" )
134+ if isinstance (value , str ):
135+ try :
136+ db .schema .get_node_schema (name = value , duplicate = False )
137+ except SchemaNotFoundError as exc :
138+ raise ValidationError (
139+ input_value = {"node_kind" : "The requested node_kind schema was not found" }
140+ ) from exc
141+ except ValueError as exc :
142+ raise ValidationError (input_value = {"node_kind" : "The requested node_kind is not a valid node" }) from exc
143+
144+
145+ def _validate_node_kind_field (data : InputObjectType , node_schema : NodeSchema ) -> None :
146+ input_data = cast (dict [str , dict [str , Any ]], data )
147+ if attribute_name := input_data .get ("attribute_name" ):
148+ value = attribute_name .get ("value" )
149+ if isinstance (value , str ):
150+ if value not in node_schema .attribute_names :
151+ raise ValidationError (
152+ input_value = {
153+ "attribute_name" : f"The attribute { value } doesn't exist on related node trigger using { node_schema .kind } "
154+ }
155+ )
156+ if relationship_name := input_data .get ("relationship_name" ):
157+ value = relationship_name .get ("value" )
158+ if isinstance (value , str ):
159+ if value not in node_schema .relationship_names :
160+ raise ValidationError (
161+ input_value = {
162+ "relationship_name" : f"The relationship { value } doesn't exist on related node trigger using { node_schema .kind } "
163+ }
164+ )
0 commit comments