|
1 | 1 | import logging
|
2 | 2 | from abc import ABC
|
3 |
| -from typing import Dict, Iterable, List, Optional, Tuple, Union |
| 3 | +from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple, Union, cast |
4 | 4 |
|
5 | 5 | from graphql import (
|
6 | 6 | ArgumentNode,
|
7 | 7 | DocumentNode,
|
8 | 8 | FieldNode,
|
9 | 9 | GraphQLArgument,
|
10 | 10 | GraphQLField,
|
| 11 | + GraphQLInputObjectType, |
| 12 | + GraphQLInputType, |
11 | 13 | GraphQLInterfaceType,
|
| 14 | + GraphQLList, |
12 | 15 | GraphQLNamedType,
|
| 16 | + GraphQLNonNull, |
13 | 17 | GraphQLObjectType,
|
14 | 18 | GraphQLSchema,
|
| 19 | + GraphQLWrappingType, |
| 20 | + ListTypeNode, |
| 21 | + ListValueNode, |
| 22 | + NamedTypeNode, |
15 | 23 | NameNode,
|
| 24 | + NonNullTypeNode, |
| 25 | + NullValueNode, |
| 26 | + ObjectFieldNode, |
| 27 | + ObjectValueNode, |
16 | 28 | OperationDefinitionNode,
|
17 | 29 | OperationType,
|
18 | 30 | SelectionSetNode,
|
19 |
| - ast_from_value, |
| 31 | + TypeNode, |
| 32 | + Undefined, |
| 33 | + ValueNode, |
| 34 | + VariableDefinitionNode, |
| 35 | + VariableNode, |
| 36 | + assert_named_type, |
| 37 | + is_input_object_type, |
| 38 | + is_list_type, |
| 39 | + is_non_null_type, |
| 40 | + is_wrapping_type, |
20 | 41 | print_ast,
|
21 | 42 | )
|
22 | 43 | from graphql.pyutils import FrozenList
|
| 44 | +from graphql.utilities import ast_from_value as default_ast_from_value |
23 | 45 |
|
24 | 46 | from .utils import to_camel_case
|
25 | 47 |
|
26 | 48 | log = logging.getLogger(__name__)
|
27 | 49 |
|
28 | 50 |
|
| 51 | +def ast_from_value(value: Any, type_: GraphQLInputType) -> Optional[ValueNode]: |
| 52 | + """ |
| 53 | + This is a partial copy paste of the ast_from_value function in |
| 54 | + graphql-core utilities/ast_from_value.py |
| 55 | +
|
| 56 | + Overwrite the if blocks that use recursion and add a new case to return a |
| 57 | + VariableNode when value is a DSLVariable |
| 58 | +
|
| 59 | + Produce a GraphQL Value AST given a Python object. |
| 60 | + """ |
| 61 | + if isinstance(value, DSLVariable): |
| 62 | + return value.set_type(type_).ast_variable |
| 63 | + |
| 64 | + if is_non_null_type(type_): |
| 65 | + type_ = cast(GraphQLNonNull, type_) |
| 66 | + ast_value = ast_from_value(value, type_.of_type) |
| 67 | + if isinstance(ast_value, NullValueNode): |
| 68 | + return None |
| 69 | + return ast_value |
| 70 | + |
| 71 | + # only explicit None, not Undefined or NaN |
| 72 | + if value is None: |
| 73 | + return NullValueNode() |
| 74 | + |
| 75 | + # undefined |
| 76 | + if value is Undefined: |
| 77 | + return None |
| 78 | + |
| 79 | + # Convert Python list to GraphQL list. If the GraphQLType is a list, but the value |
| 80 | + # is not a list, convert the value using the list's item type. |
| 81 | + if is_list_type(type_): |
| 82 | + type_ = cast(GraphQLList, type_) |
| 83 | + item_type = type_.of_type |
| 84 | + if isinstance(value, Iterable) and not isinstance(value, str): |
| 85 | + maybe_value_nodes = (ast_from_value(item, item_type) for item in value) |
| 86 | + value_nodes = filter(None, maybe_value_nodes) |
| 87 | + return ListValueNode(values=FrozenList(value_nodes)) |
| 88 | + return ast_from_value(value, item_type) |
| 89 | + |
| 90 | + # Populate the fields of the input object by creating ASTs from each value in the |
| 91 | + # Python dict according to the fields in the input type. |
| 92 | + if is_input_object_type(type_): |
| 93 | + if value is None or not isinstance(value, Mapping): |
| 94 | + return None |
| 95 | + type_ = cast(GraphQLInputObjectType, type_) |
| 96 | + field_items = ( |
| 97 | + (field_name, ast_from_value(value[field_name], field.type)) |
| 98 | + for field_name, field in type_.fields.items() |
| 99 | + if field_name in value |
| 100 | + ) |
| 101 | + field_nodes = ( |
| 102 | + ObjectFieldNode(name=NameNode(value=field_name), value=field_value) |
| 103 | + for field_name, field_value in field_items |
| 104 | + if field_value |
| 105 | + ) |
| 106 | + return ObjectValueNode(fields=FrozenList(field_nodes)) |
| 107 | + |
| 108 | + return default_ast_from_value(value, type_) |
| 109 | + |
| 110 | + |
29 | 111 | def dsl_gql(
|
30 | 112 | *operations: "DSLOperation", **operations_with_name: "DSLOperation"
|
31 | 113 | ) -> DocumentNode:
|
@@ -77,6 +159,9 @@ def dsl_gql(
|
77 | 159 | OperationDefinitionNode(
|
78 | 160 | operation=OperationType(operation.operation_type),
|
79 | 161 | selection_set=operation.selection_set,
|
| 162 | + variable_definitions=FrozenList( |
| 163 | + operation.variable_definitions.get_ast_definitions() |
| 164 | + ), |
80 | 165 | **({"name": NameNode(value=operation.name)} if operation.name else {}),
|
81 | 166 | )
|
82 | 167 | for operation in all_operations
|
@@ -156,6 +241,7 @@ def __init__(
|
156 | 241 | """
|
157 | 242 |
|
158 | 243 | self.name: Optional[str] = None
|
| 244 | + self.variable_definitions: DSLVariableDefinitions = DSLVariableDefinitions() |
159 | 245 |
|
160 | 246 | # Concatenate fields without and with alias
|
161 | 247 | all_fields: Tuple["DSLField", ...] = DSLField.get_aliased_fields(
|
@@ -194,6 +280,75 @@ class DSLSubscription(DSLOperation):
|
194 | 280 | operation_type = OperationType.SUBSCRIPTION
|
195 | 281 |
|
196 | 282 |
|
| 283 | +class DSLVariable: |
| 284 | + """The DSLVariable represents a single variable defined in a GraphQL operation |
| 285 | +
|
| 286 | + Instances of this class are generated for you automatically as attributes |
| 287 | + of the :class:`DSLVariableDefinitions` |
| 288 | +
|
| 289 | + The type of the variable is set by the :class:`DSLField` instance that receives it |
| 290 | + in the `args` method. |
| 291 | + """ |
| 292 | + |
| 293 | + def __init__(self, name: str): |
| 294 | + self.type: Optional[TypeNode] = None |
| 295 | + self.name = name |
| 296 | + self.ast_variable = VariableNode(name=NameNode(value=self.name)) |
| 297 | + |
| 298 | + def to_ast_type( |
| 299 | + self, type_: Union[GraphQLWrappingType, GraphQLNamedType] |
| 300 | + ) -> TypeNode: |
| 301 | + if is_wrapping_type(type_): |
| 302 | + if isinstance(type_, GraphQLList): |
| 303 | + return ListTypeNode(type=self.to_ast_type(type_.of_type)) |
| 304 | + elif isinstance(type_, GraphQLNonNull): |
| 305 | + return NonNullTypeNode(type=self.to_ast_type(type_.of_type)) |
| 306 | + |
| 307 | + type_ = assert_named_type(type_) |
| 308 | + return NamedTypeNode(name=NameNode(value=type_.name)) |
| 309 | + |
| 310 | + def set_type( |
| 311 | + self, type_: Union[GraphQLWrappingType, GraphQLNamedType] |
| 312 | + ) -> "DSLVariable": |
| 313 | + self.type = self.to_ast_type(type_) |
| 314 | + return self |
| 315 | + |
| 316 | + |
| 317 | +class DSLVariableDefinitions: |
| 318 | + """The DSLVariableDefinitions represents variable definitions in a GraphQL operation |
| 319 | +
|
| 320 | + Instances of this class have to be created and set as the `variable_definitions` |
| 321 | + attribute of a DSLOperation instance |
| 322 | +
|
| 323 | + Attributes of the DSLVariableDefinitions class are generated automatically |
| 324 | + with the `__getattr__` dunder method in order to generate |
| 325 | + instances of :class:`DSLVariable`, that can then be used as values in the |
| 326 | + `DSLField.args` method |
| 327 | + """ |
| 328 | + |
| 329 | + def __init__(self): |
| 330 | + self.variables: Dict[str, DSLVariable] = {} |
| 331 | + |
| 332 | + def __getattr__(self, name: str) -> "DSLVariable": |
| 333 | + if name not in self.variables: |
| 334 | + self.variables[name] = DSLVariable(name) |
| 335 | + return self.variables[name] |
| 336 | + |
| 337 | + def get_ast_definitions(self) -> List[VariableDefinitionNode]: |
| 338 | + """ |
| 339 | + :meta private: |
| 340 | +
|
| 341 | + Return a list of VariableDefinitionNodes for each variable with a type |
| 342 | + """ |
| 343 | + return [ |
| 344 | + VariableDefinitionNode( |
| 345 | + type=var.type, variable=var.ast_variable, default_value=None, |
| 346 | + ) |
| 347 | + for var in self.variables.values() |
| 348 | + if var.type is not None # only variables used |
| 349 | + ] |
| 350 | + |
| 351 | + |
197 | 352 | class DSLType:
|
198 | 353 | """The DSLType represents a GraphQL type for the DSL code.
|
199 | 354 |
|
|
0 commit comments