Skip to content

Commit 7c5b97a

Browse files
committed
Add parsing of variables and queries to GraphQL analyzer
1 parent 281aeab commit 7c5b97a

File tree

3 files changed

+386
-21
lines changed

3 files changed

+386
-21
lines changed

backend/infrahub/graphql/analyzer.py

Lines changed: 139 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,27 @@
1313
FragmentSpreadNode,
1414
GraphQLSchema,
1515
InlineFragmentNode,
16+
ListTypeNode,
1617
NamedTypeNode,
1718
NonNullTypeNode,
1819
OperationDefinitionNode,
1920
OperationType,
2021
SelectionSetNode,
22+
TypeNode,
23+
)
24+
from graphql.language.ast import (
25+
BooleanValueNode,
26+
ConstListValueNode,
27+
ConstObjectValueNode,
28+
EnumValueNode,
29+
FloatValueNode,
30+
IntValueNode,
31+
ListValueNode,
32+
NullValueNode,
33+
ObjectValueNode,
34+
StringValueNode,
35+
ValueNode,
36+
VariableNode,
2137
)
2238
from infrahub_sdk.analyzer import GraphQLQueryAnalyzer
2339
from infrahub_sdk.utils import extract_fields
@@ -91,9 +107,24 @@ class GraphQLSelectionSet:
91107
@dataclass
92108
class GraphQLArgument:
93109
name: str
94-
value: str
110+
value: Any
95111
kind: str
96112

113+
@property
114+
def is_variable(self) -> bool:
115+
return self.kind == "variable"
116+
117+
@property
118+
def as_variable_name(self) -> str:
119+
"""Return the name without a $ prefix"""
120+
return str(self.value).removeprefix("$")
121+
122+
@property
123+
def fields(self) -> list[str]:
124+
if self.kind != "object_value" or not isinstance(self.value, dict):
125+
return []
126+
return sorted(self.value.keys())
127+
97128

98129
@dataclass
99130
class ObjectAccess:
@@ -106,6 +137,9 @@ class GraphQLVariable:
106137
name: str
107138
type: str
108139
required: bool
140+
is_list: bool = False
141+
inner_required: bool = False
142+
default: Any | None = None
109143

110144

111145
@dataclass
@@ -266,6 +300,28 @@ def fields_by_kind(self, kind: str) -> list[str]:
266300

267301
return fields
268302

303+
@cached_property
304+
def variables(self) -> list[GraphQLVariable]:
305+
"""Return input variables defined on the query document
306+
307+
All subqueries will use the same document level queries,
308+
so only the first entry is required
309+
"""
310+
if self.queries:
311+
return self.queries[0].variables
312+
return []
313+
314+
def required_argument(self, argument: GraphQLArgument) -> bool:
315+
if not argument.is_variable:
316+
# If the argument isn't a variable it would have been
317+
# statically defined in the input and as such required
318+
return True
319+
for variable in self.variables:
320+
if variable.name == argument.as_variable_name and variable.required:
321+
return True
322+
323+
return False
324+
269325
@cached_property
270326
def top_level_kinds(self) -> list[str]:
271327
return [query.infrahub_model.kind for query in self.queries if query.infrahub_model]
@@ -298,6 +354,22 @@ def kind_action_map(self) -> dict[str, set[MutateAction]]:
298354

299355
return access
300356

357+
@property
358+
def only_has_unique_targets(self) -> bool:
359+
"""Indicate if the query document is defined so that it will return a single root level object"""
360+
for query in self.queries:
361+
targets_single_query = False
362+
if query.infrahub_model and query.infrahub_model.uniqueness_constraints:
363+
for argument in query.arguments:
364+
if [[argument.name]] == query.infrahub_model.uniqueness_constraints:
365+
if self.required_argument(argument=argument):
366+
targets_single_query = True
367+
368+
if not targets_single_query:
369+
return False
370+
371+
return True
372+
301373

302374
class InfrahubGraphQLQueryAnalyzer(GraphQLQueryAnalyzer):
303375
def __init__(
@@ -603,31 +675,80 @@ def _get_selections(selection_set: SelectionSetNode) -> GraphQLSelectionSet:
603675
],
604676
)
605677

606-
@staticmethod
607-
def _get_variables(operation: OperationDefinitionNode) -> list[GraphQLVariable]:
608-
variables = []
609-
for variable in operation.variable_definitions:
610-
if isinstance(variable.type, NamedTypeNode):
611-
variables.append(
612-
GraphQLVariable(name=variable.variable.name.value, type=variable.type.name.value, required=False)
678+
def _get_variables(self, operation: OperationDefinitionNode) -> list[GraphQLVariable]:
679+
variables: list[GraphQLVariable] = []
680+
681+
for variable in operation.variable_definitions or []:
682+
type_node: TypeNode = variable.type
683+
required = False
684+
is_list = False
685+
inner_required = False
686+
687+
if isinstance(type_node, NonNullTypeNode):
688+
required = True
689+
type_node = type_node.type
690+
691+
if isinstance(type_node, ListTypeNode):
692+
is_list = True
693+
inner_type = type_node.type
694+
695+
if isinstance(inner_type, NonNullTypeNode):
696+
inner_required = True
697+
inner_type = inner_type.type
698+
699+
if isinstance(inner_type, NamedTypeNode):
700+
type_name = inner_type.name.value
701+
else:
702+
raise TypeError(f"Unsupported inner type node: {inner_type}")
703+
elif isinstance(type_node, NamedTypeNode):
704+
type_name = type_node.name.value
705+
else:
706+
raise TypeError(f"Unsupported type node: {type_node}")
707+
708+
variables.append(
709+
GraphQLVariable(
710+
name=variable.variable.name.value,
711+
type=type_name,
712+
required=required,
713+
is_list=is_list,
714+
inner_required=inner_required,
715+
default=self._parse_value(variable.default_value) if variable.default_value else None,
613716
)
614-
elif isinstance(variable.type, NonNullTypeNode):
615-
if isinstance(variable.type.type, NamedTypeNode):
616-
variables.append(
617-
GraphQLVariable(
618-
name=variable.variable.name.value, type=variable.type.type.name.value, required=True
619-
)
620-
)
717+
)
621718

622719
return variables
623720

624-
@staticmethod
625-
def _parse_arguments(field_node: FieldNode) -> list[GraphQLArgument]:
721+
def _parse_arguments(self, field_node: FieldNode) -> list[GraphQLArgument]:
626722
return [
627723
GraphQLArgument(
628724
name=argument.name.value,
629-
value=getattr(argument.value, "value", ""),
725+
value=self._parse_value(argument.value),
630726
kind=argument.value.kind,
631727
)
632728
for argument in field_node.arguments
633729
]
730+
731+
def _parse_value(self, node: ValueNode) -> Any:
732+
match node:
733+
case VariableNode():
734+
value: Any = f"${node.name.value}"
735+
case IntValueNode():
736+
value = int(node.value)
737+
case FloatValueNode():
738+
value = float(node.value)
739+
case StringValueNode():
740+
value = node.value
741+
case BooleanValueNode():
742+
value = node.value
743+
case NullValueNode():
744+
value = None
745+
case EnumValueNode():
746+
value = node.value
747+
case ListValueNode() | ConstListValueNode():
748+
value = [self._parse_value(item) for item in node.values]
749+
case ObjectValueNode() | ConstObjectValueNode():
750+
value = {field.name.value: self._parse_value(field.value) for field in node.fields}
751+
case _:
752+
raise TypeError(f"Unsupported value node: {node}")
753+
754+
return value

backend/tests/helpers/schema/tshirt.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
label="T-shirt",
1111
default_filter="name__value",
1212
display_labels=["name__value"],
13+
uniqueness_constraints=[["name__value"]],
1314
attributes=[
1415
AttributeSchema(name="name", kind="Text"),
1516
AttributeSchema(

0 commit comments

Comments
 (0)