1
1
# Copyright 2019-present Kensho Technologies, LLC.
2
- from collections import namedtuple
3
2
from copy import copy
3
+ from dataclasses import dataclass
4
+ from typing import FrozenSet , List , Optional , Tuple , cast
4
5
5
6
from graphql import print_ast
6
7
from graphql .language .ast import (
15
16
SelectionSetNode ,
16
17
StringValueNode ,
17
18
)
19
+ from graphql .pyutils import FrozenList
18
20
19
21
from ..ast_manipulation import get_only_query_definition
20
22
from ..exceptions import GraphQLValidationError
21
23
from ..schema import FilterDirective , OutputDirective
24
+ from .split_query import AstType , SubQueryNode
22
25
23
26
24
- SubQueryPlan = namedtuple (
25
- "SubQueryPlan" ,
26
- (
27
- "query_ast" , # Document, representing a piece of the overall query with directives added
28
- "schema_id" , # str, identifying the schema that this query piece targets
29
- "parent_query_plan" , # SubQueryPlan, the query that the current query depends on
30
- "child_query_plans" , # List[SubQueryPlan], the queries that depend on the current query
31
- ),
32
- )
27
+ @dataclass
28
+ class SubQueryPlan :
29
+ """Query plan for a part of a larger query over a single schema."""
33
30
31
+ # Representing a piece of the overall query with directives added.
32
+ query_ast : DocumentNode
34
33
35
- OutputJoinDescriptor = namedtuple (
36
- "OutputJoinDescriptor" ,
37
- (
38
- "output_names" , # Tuple[str, str], (parent output name, child output name)
39
- # May be expanded to have more attributes, e.g. is_optional, describing how the join
40
- # should be made
41
- ),
42
- )
34
+ # Identifier for the schema that this query piece targets.
35
+ schema_id : Optional [str ]
43
36
37
+ # The query that the current query depends on, or None if the current query does not
38
+ # depend on another.
39
+ parent_query_plan : Optional ["SubQueryPlan" ]
44
40
45
- QueryPlanDescriptor = namedtuple (
46
- "QueryPlanDescriptor" ,
47
- (
48
- "root_sub_query_plan" , # SubQueryPlan
49
- "intermediate_output_names" , # frozenset[str], names of outputs to be removed at the end
50
- "output_join_descriptors" ,
51
- # List[OutputJoinDescriptor], describing which outputs should be joined and how
52
- ),
53
- )
41
+ # The queries that depend on the current query.
42
+ child_query_plans : List ["SubQueryPlan" ]
43
+
44
+
45
+ @dataclass (frozen = True )
46
+ class OutputJoinDescriptor :
47
+ """Description of what outputs should be joined and how."""
48
+
49
+ # (parent output name, child output name)
50
+ # May be expanded to have more attributes, e.g. is_optional, describing how the join
51
+ # should be made.
52
+ output_names : Tuple [str , str ]
53
+
54
+
55
+ @dataclass (frozen = True )
56
+ class QueryPlanDescriptor :
57
+ """Describes a query plan including output join information and intermediate output names."""
58
+
59
+ # The root of the query plan.
60
+ root_sub_query_plan : SubQueryPlan
54
61
62
+ # Names of outputs to be removed at the end.
63
+ intermediate_output_names : FrozenSet [str ]
55
64
56
- def make_query_plan (root_sub_query_node , intermediate_output_names ):
65
+ # Describing which outputs should be joined and how.
66
+ output_join_descriptors : List [OutputJoinDescriptor ]
67
+
68
+
69
+ def make_query_plan (
70
+ root_sub_query_node : SubQueryNode , intermediate_output_names : FrozenSet [str ]
71
+ ) -> QueryPlanDescriptor :
57
72
"""Return a QueryPlanDescriptor, whose query ASTs have @filters added.
58
73
59
74
For each parent of parent and child SubQueryNodes, a new @filter directive will be added
@@ -65,17 +80,16 @@ def make_query_plan(root_sub_query_node, intermediate_output_names):
65
80
ASTs contained in the input node and its children nodes will not be modified.
66
81
67
82
Args:
68
- root_sub_query_node: SubQueryNode, representing the base of a query split into pieces
69
- that we want to turn into a query plan
70
- intermediate_output_names: frozenset[str], names of outputs to be removed at the end
83
+ root_sub_query_node: representing the base of a query split into pieces
84
+ that we want to turn into a query plan.
85
+ intermediate_output_names: names of outputs to be removed at the end.
71
86
72
87
Returns:
73
- QueryPlanDescriptor namedtuple, containing a tree of SubQueryPlans that wrap
74
- around each individual query AST, the set of intermediate output names that are
75
- to be removed at the end, and information on which outputs are to be connect to which
76
- in what manner
88
+ QueryPlanDescriptor containing a tree of SubQueryPlans that wrap around each individual
89
+ query AST, the set of intermediate output names that are to be removed at the end, and
90
+ information on which outputs are to be connect to which in what manner.
77
91
"""
78
- output_join_descriptors = []
92
+ output_join_descriptors : List [ OutputJoinDescriptor ] = []
79
93
80
94
root_sub_query_plan = SubQueryPlan (
81
95
query_ast = root_sub_query_node .query_ast ,
@@ -93,20 +107,23 @@ def make_query_plan(root_sub_query_node, intermediate_output_names):
93
107
)
94
108
95
109
96
- def _make_query_plan_recursive (sub_query_node , sub_query_plan , output_join_descriptors ):
110
+ def _make_query_plan_recursive (
111
+ sub_query_node : SubQueryNode ,
112
+ sub_query_plan : SubQueryPlan ,
113
+ output_join_descriptors : List [OutputJoinDescriptor ],
114
+ ) -> None :
97
115
"""Recursively copy the structure of sub_query_node onto sub_query_plan.
98
116
99
117
For each child connection contained in sub_query_node, create a new SubQueryPlan for
100
118
the corresponding child SubQueryNode, add appropriate @filter directive to the child AST,
101
119
and attach the new SubQueryPlan to the list of children of the input sub-query plan.
102
120
103
121
Args:
104
- sub_query_node: SubQueryNode, whose descendents are copied over onto sub_query_plan.
105
- It is not modified by this function
122
+ sub_query_node: SubQueryNode, whose child_query_connections are copied over onto
123
+ sub_query_plan. It is not modified by this function.
106
124
sub_query_plan: SubQueryPlan, whose list of child query plans and query AST are
107
- modified
108
- output_join_descriptors: List[OutputJoinDescriptor], describing which outputs should be
109
- joined and how
125
+ modified.
126
+ output_join_descriptors: describing which outputs should be joined and how.
110
127
111
128
"""
112
129
# Iterate through child connections of query node
@@ -155,22 +172,22 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr
155
172
)
156
173
157
174
158
- def _add_filter_at_field_with_output (ast , field_out_name , input_filter_name ):
175
+ def _add_filter_at_field_with_output (
176
+ ast : AstType , field_out_name : str , input_filter_name : str
177
+ ) -> AstType :
159
178
"""Return an AST with @filter added at the field with the specified @output, if found.
160
179
161
180
Args:
162
- ast: Field, InlineFragment, or OperationDefinition, an AST Node type that occurs in
163
- the selections of a SelectionSet. It is not modified by this function
164
- field_out_name: str, the out_name of an @output directive. This function will create
181
+ ast: AST Node type that occurs in the selections of a SelectionSet.
182
+ It is not modified by this function.
183
+ field_out_name: the out_name of an @output directive. This function will create
165
184
a new @filter directive on the field that has an @output directive
166
- with this out_name
167
- input_filter_name: str, the name of the local variable in the new @filter directive
168
- created
185
+ with this out_name.
186
+ input_filter_name: the name of the local variable in the new @filter directive created.
169
187
170
188
Returns:
171
- Field, InlineFragment, or OperationDefinition, identical to the input ast except
172
- with an @filter added at the specified field if such a field is found. If no changes
173
- were made, this is the same object as the input
189
+ AST node identical to the input AST except with a @filter added at the specified field if
190
+ such a field is found. If no changes were made, this is the same object as the input.
174
191
"""
175
192
if not isinstance (ast , (FieldNode , InlineFragmentNode , OperationDefinitionNode )):
176
193
raise AssertionError (
@@ -187,7 +204,7 @@ def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
187
204
new_directives = list (ast .directives )
188
205
new_directives .append (_get_in_collection_filter_directive (input_filter_name ))
189
206
new_ast = copy (ast )
190
- new_ast .directives = new_directives
207
+ new_ast .directives = cast ( FrozenList , new_directives )
191
208
return new_ast
192
209
193
210
if ast .selection_set is None : # Nothing to recurse on
@@ -197,10 +214,18 @@ def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
197
214
made_changes = False
198
215
new_selections = []
199
216
for selection in ast .selection_set .selections :
217
+ # Make sure selection is a FieldNode of InlineFragment and cast to AST type
218
+ # to make mypy happy.
219
+ if not isinstance (selection , FieldNode ) and not isinstance (selection , InlineFragmentNode ):
220
+ raise AssertionError (
221
+ f"Unexpected selection type { type (selection )} . Only FieldNodes and "
222
+ "InlineFragmentNodes are expected."
223
+ )
224
+ ast_type_selection = cast (AstType , selection )
200
225
new_selection = _add_filter_at_field_with_output (
201
- selection , field_out_name , input_filter_name
226
+ ast_type_selection , field_out_name , input_filter_name
202
227
)
203
- if new_selection is not selection : # Changes made somewhere down the line
228
+ if new_selection is not ast_type_selection : # Changes made somewhere down the line
204
229
if not made_changes :
205
230
made_changes = True
206
231
else :
@@ -221,18 +246,32 @@ def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
221
246
return ast
222
247
223
248
224
- def _is_output_directive_with_name (directive , out_name ) :
249
+ def _is_output_directive_with_name (directive : DirectiveNode , out_name : str ) -> bool :
225
250
"""Return whether or not the input is an @output directive with the desired out_name."""
226
251
if not isinstance (directive , DirectiveNode ):
227
252
raise AssertionError ('Input "{}" is not a directive.' .format (directive ))
228
- return (
229
- directive .name .value == OutputDirective .name
230
- and directive .arguments [0 ].value .value == out_name
231
- )
253
+ # Check whether or not this directive is an output directive.
254
+ if directive .name .value != OutputDirective .name :
255
+ return False
256
+ # Ensure the output directive has arguments since @output takes an `out_name`.
257
+ if not directive .arguments :
258
+ raise AssertionError (
259
+ "directive is an OutputDirective, but has no arguments. This should be impossible! "
260
+ f"directive: { directive } "
261
+ )
262
+ # Ensure he output directive argument is a string since output directives must have a
263
+ # non-null string `out_name`.
264
+ directive_out_name_value_node = directive .arguments [0 ].value
265
+ if not isinstance (directive_out_name_value_node , StringValueNode ):
266
+ raise AssertionError (
267
+ "directive is an OutputDirective, but has a non-string argument. "
268
+ f"This should be impossible! directive: { directive } "
269
+ )
270
+ return directive_out_name_value_node .value == out_name
232
271
233
272
234
- def _get_in_collection_filter_directive (input_filter_name ) :
235
- """Create a @filter directive with in_collecion operation and the desired variable name."""
273
+ def _get_in_collection_filter_directive (input_filter_name : str ) -> DirectiveNode :
274
+ """Create a @filter directive with in_collection operation and the desired variable name."""
236
275
return DirectiveNode (
237
276
name = NameNode (value = FilterDirective .name ),
238
277
arguments = [
@@ -252,7 +291,7 @@ def _get_in_collection_filter_directive(input_filter_name):
252
291
)
253
292
254
293
255
- def print_query_plan (query_plan_descriptor , indentation_depth = 4 ) :
294
+ def print_query_plan (query_plan_descriptor : QueryPlanDescriptor , indentation_depth : int = 4 ) -> str :
256
295
"""Return a string describing query plan."""
257
296
query_plan_strings = ["" ]
258
297
plan_and_depth = _get_plan_and_depth_in_dfs_order (query_plan_descriptor .root_sub_query_plan )
@@ -274,7 +313,7 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4):
274
313
return "" .join (query_plan_strings )
275
314
276
315
277
- def _get_plan_and_depth_in_dfs_order (query_plan ) :
316
+ def _get_plan_and_depth_in_dfs_order (query_plan : SubQueryPlan ) -> List [ Tuple [ SubQueryPlan , int ]] :
278
317
"""Return a list of topologically sorted (query plan, depth) tuples."""
279
318
280
319
def _get_plan_and_depth_in_dfs_order_helper (query_plan , depth ):
0 commit comments