Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Commit 33cc5e8

Browse files
chewseleneSelene Chew
andauthored
Type hint make query plan (#990)
* Type hint make_query_plan and convert namedtuple to dataclass * add docstring * fix typos * typing copilot tighten * check for arguments Co-authored-by: Selene Chew <[email protected]>
1 parent 302edd7 commit 33cc5e8

File tree

2 files changed

+103
-66
lines changed

2 files changed

+103
-66
lines changed

graphql_compiler/schema_transformation/make_query_plan.py

Lines changed: 103 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2019-present Kensho Technologies, LLC.
2-
from collections import namedtuple
32
from copy import copy
3+
from dataclasses import dataclass
4+
from typing import FrozenSet, List, Optional, Tuple, cast
45

56
from graphql import print_ast
67
from graphql.language.ast import (
@@ -15,45 +16,59 @@
1516
SelectionSetNode,
1617
StringValueNode,
1718
)
19+
from graphql.pyutils import FrozenList
1820

1921
from ..ast_manipulation import get_only_query_definition
2022
from ..exceptions import GraphQLValidationError
2123
from ..schema import FilterDirective, OutputDirective
24+
from .split_query import AstType, SubQueryNode
2225

2326

24-
SubQueryPlan = namedtuple(
25-
"SubQueryPlan",
26-
(
27-
"query_ast", # Document, representing a piece of the overall query with directives added
28-
"schema_id", # str, identifying the schema that this query piece targets
29-
"parent_query_plan", # SubQueryPlan, the query that the current query depends on
30-
"child_query_plans", # List[SubQueryPlan], the queries that depend on the current query
31-
),
32-
)
27+
@dataclass
28+
class SubQueryPlan:
29+
"""Query plan for a part of a larger query over a single schema."""
3330

31+
# Representing a piece of the overall query with directives added.
32+
query_ast: DocumentNode
3433

35-
OutputJoinDescriptor = namedtuple(
36-
"OutputJoinDescriptor",
37-
(
38-
"output_names", # Tuple[str, str], (parent output name, child output name)
39-
# May be expanded to have more attributes, e.g. is_optional, describing how the join
40-
# should be made
41-
),
42-
)
34+
# Identifier for the schema that this query piece targets.
35+
schema_id: Optional[str]
4336

37+
# The query that the current query depends on, or None if the current query does not
38+
# depend on another.
39+
parent_query_plan: Optional["SubQueryPlan"]
4440

45-
QueryPlanDescriptor = namedtuple(
46-
"QueryPlanDescriptor",
47-
(
48-
"root_sub_query_plan", # SubQueryPlan
49-
"intermediate_output_names", # frozenset[str], names of outputs to be removed at the end
50-
"output_join_descriptors",
51-
# List[OutputJoinDescriptor], describing which outputs should be joined and how
52-
),
53-
)
41+
# The queries that depend on the current query.
42+
child_query_plans: List["SubQueryPlan"]
43+
44+
45+
@dataclass(frozen=True)
46+
class OutputJoinDescriptor:
47+
"""Description of what outputs should be joined and how."""
48+
49+
# (parent output name, child output name)
50+
# May be expanded to have more attributes, e.g. is_optional, describing how the join
51+
# should be made.
52+
output_names: Tuple[str, str]
53+
54+
55+
@dataclass(frozen=True)
56+
class QueryPlanDescriptor:
57+
"""Describes a query plan including output join information and intermediate output names."""
58+
59+
# The root of the query plan.
60+
root_sub_query_plan: SubQueryPlan
5461

62+
# Names of outputs to be removed at the end.
63+
intermediate_output_names: FrozenSet[str]
5564

56-
def make_query_plan(root_sub_query_node, intermediate_output_names):
65+
# Describing which outputs should be joined and how.
66+
output_join_descriptors: List[OutputJoinDescriptor]
67+
68+
69+
def make_query_plan(
70+
root_sub_query_node: SubQueryNode, intermediate_output_names: FrozenSet[str]
71+
) -> QueryPlanDescriptor:
5772
"""Return a QueryPlanDescriptor, whose query ASTs have @filters added.
5873
5974
For each parent of parent and child SubQueryNodes, a new @filter directive will be added
@@ -65,17 +80,16 @@ def make_query_plan(root_sub_query_node, intermediate_output_names):
6580
ASTs contained in the input node and its children nodes will not be modified.
6681
6782
Args:
68-
root_sub_query_node: SubQueryNode, representing the base of a query split into pieces
69-
that we want to turn into a query plan
70-
intermediate_output_names: frozenset[str], names of outputs to be removed at the end
83+
root_sub_query_node: representing the base of a query split into pieces
84+
that we want to turn into a query plan.
85+
intermediate_output_names: names of outputs to be removed at the end.
7186
7287
Returns:
73-
QueryPlanDescriptor namedtuple, containing a tree of SubQueryPlans that wrap
74-
around each individual query AST, the set of intermediate output names that are
75-
to be removed at the end, and information on which outputs are to be connect to which
76-
in what manner
88+
QueryPlanDescriptor containing a tree of SubQueryPlans that wrap around each individual
89+
query AST, the set of intermediate output names that are to be removed at the end, and
90+
information on which outputs are to be connect to which in what manner.
7791
"""
78-
output_join_descriptors = []
92+
output_join_descriptors: List[OutputJoinDescriptor] = []
7993

8094
root_sub_query_plan = SubQueryPlan(
8195
query_ast=root_sub_query_node.query_ast,
@@ -93,20 +107,23 @@ def make_query_plan(root_sub_query_node, intermediate_output_names):
93107
)
94108

95109

96-
def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors):
110+
def _make_query_plan_recursive(
111+
sub_query_node: SubQueryNode,
112+
sub_query_plan: SubQueryPlan,
113+
output_join_descriptors: List[OutputJoinDescriptor],
114+
) -> None:
97115
"""Recursively copy the structure of sub_query_node onto sub_query_plan.
98116
99117
For each child connection contained in sub_query_node, create a new SubQueryPlan for
100118
the corresponding child SubQueryNode, add appropriate @filter directive to the child AST,
101119
and attach the new SubQueryPlan to the list of children of the input sub-query plan.
102120
103121
Args:
104-
sub_query_node: SubQueryNode, whose descendents are copied over onto sub_query_plan.
105-
It is not modified by this function
122+
sub_query_node: SubQueryNode, whose child_query_connections are copied over onto
123+
sub_query_plan. It is not modified by this function.
106124
sub_query_plan: SubQueryPlan, whose list of child query plans and query AST are
107-
modified
108-
output_join_descriptors: List[OutputJoinDescriptor], describing which outputs should be
109-
joined and how
125+
modified.
126+
output_join_descriptors: describing which outputs should be joined and how.
110127
111128
"""
112129
# Iterate through child connections of query node
@@ -155,22 +172,22 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr
155172
)
156173

157174

158-
def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
175+
def _add_filter_at_field_with_output(
176+
ast: AstType, field_out_name: str, input_filter_name: str
177+
) -> AstType:
159178
"""Return an AST with @filter added at the field with the specified @output, if found.
160179
161180
Args:
162-
ast: Field, InlineFragment, or OperationDefinition, an AST Node type that occurs in
163-
the selections of a SelectionSet. It is not modified by this function
164-
field_out_name: str, the out_name of an @output directive. This function will create
181+
ast: AST Node type that occurs in the selections of a SelectionSet.
182+
It is not modified by this function.
183+
field_out_name: the out_name of an @output directive. This function will create
165184
a new @filter directive on the field that has an @output directive
166-
with this out_name
167-
input_filter_name: str, the name of the local variable in the new @filter directive
168-
created
185+
with this out_name.
186+
input_filter_name: the name of the local variable in the new @filter directive created.
169187
170188
Returns:
171-
Field, InlineFragment, or OperationDefinition, identical to the input ast except
172-
with an @filter added at the specified field if such a field is found. If no changes
173-
were made, this is the same object as the input
189+
AST node identical to the input AST except with a @filter added at the specified field if
190+
such a field is found. If no changes were made, this is the same object as the input.
174191
"""
175192
if not isinstance(ast, (FieldNode, InlineFragmentNode, OperationDefinitionNode)):
176193
raise AssertionError(
@@ -187,7 +204,7 @@ def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
187204
new_directives = list(ast.directives)
188205
new_directives.append(_get_in_collection_filter_directive(input_filter_name))
189206
new_ast = copy(ast)
190-
new_ast.directives = new_directives
207+
new_ast.directives = cast(FrozenList, new_directives)
191208
return new_ast
192209

193210
if ast.selection_set is None: # Nothing to recurse on
@@ -197,10 +214,18 @@ def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
197214
made_changes = False
198215
new_selections = []
199216
for selection in ast.selection_set.selections:
217+
# Make sure selection is a FieldNode of InlineFragment and cast to AST type
218+
# to make mypy happy.
219+
if not isinstance(selection, FieldNode) and not isinstance(selection, InlineFragmentNode):
220+
raise AssertionError(
221+
f"Unexpected selection type {type(selection)}. Only FieldNodes and "
222+
"InlineFragmentNodes are expected."
223+
)
224+
ast_type_selection = cast(AstType, selection)
200225
new_selection = _add_filter_at_field_with_output(
201-
selection, field_out_name, input_filter_name
226+
ast_type_selection, field_out_name, input_filter_name
202227
)
203-
if new_selection is not selection: # Changes made somewhere down the line
228+
if new_selection is not ast_type_selection: # Changes made somewhere down the line
204229
if not made_changes:
205230
made_changes = True
206231
else:
@@ -221,18 +246,32 @@ def _add_filter_at_field_with_output(ast, field_out_name, input_filter_name):
221246
return ast
222247

223248

224-
def _is_output_directive_with_name(directive, out_name):
249+
def _is_output_directive_with_name(directive: DirectiveNode, out_name: str) -> bool:
225250
"""Return whether or not the input is an @output directive with the desired out_name."""
226251
if not isinstance(directive, DirectiveNode):
227252
raise AssertionError('Input "{}" is not a directive.'.format(directive))
228-
return (
229-
directive.name.value == OutputDirective.name
230-
and directive.arguments[0].value.value == out_name
231-
)
253+
# Check whether or not this directive is an output directive.
254+
if directive.name.value != OutputDirective.name:
255+
return False
256+
# Ensure the output directive has arguments since @output takes an `out_name`.
257+
if not directive.arguments:
258+
raise AssertionError(
259+
"directive is an OutputDirective, but has no arguments. This should be impossible! "
260+
f"directive: {directive}"
261+
)
262+
# Ensure he output directive argument is a string since output directives must have a
263+
# non-null string `out_name`.
264+
directive_out_name_value_node = directive.arguments[0].value
265+
if not isinstance(directive_out_name_value_node, StringValueNode):
266+
raise AssertionError(
267+
"directive is an OutputDirective, but has a non-string argument. "
268+
f"This should be impossible! directive: {directive}"
269+
)
270+
return directive_out_name_value_node.value == out_name
232271

233272

234-
def _get_in_collection_filter_directive(input_filter_name):
235-
"""Create a @filter directive with in_collecion operation and the desired variable name."""
273+
def _get_in_collection_filter_directive(input_filter_name: str) -> DirectiveNode:
274+
"""Create a @filter directive with in_collection operation and the desired variable name."""
236275
return DirectiveNode(
237276
name=NameNode(value=FilterDirective.name),
238277
arguments=[
@@ -252,7 +291,7 @@ def _get_in_collection_filter_directive(input_filter_name):
252291
)
253292

254293

255-
def print_query_plan(query_plan_descriptor, indentation_depth=4):
294+
def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_depth: int = 4) -> str:
256295
"""Return a string describing query plan."""
257296
query_plan_strings = [""]
258297
plan_and_depth = _get_plan_and_depth_in_dfs_order(query_plan_descriptor.root_sub_query_plan)
@@ -274,7 +313,7 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4):
274313
return "".join(query_plan_strings)
275314

276315

277-
def _get_plan_and_depth_in_dfs_order(query_plan):
316+
def _get_plan_and_depth_in_dfs_order(query_plan: SubQueryPlan) -> List[Tuple[SubQueryPlan, int]]:
278317
"""Return a list of topologically sorted (query plan, depth) tuples."""
279318

280319
def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth):

mypy.ini

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ check_untyped_defs = False
191191
disallow_untyped_calls = False
192192

193193
[mypy-graphql_compiler.schema_transformation.make_query_plan.*]
194-
check_untyped_defs = False
195194
disallow_untyped_defs = False
196195

197196
[mypy-graphql_compiler.schema_transformation.split_query.*]
@@ -227,7 +226,6 @@ disallow_untyped_defs = False
227226

228227
[mypy-graphql_compiler.tests.schema_transformation_tests.test_make_query_plan.*]
229228
disallow_incomplete_defs = False
230-
disallow_untyped_calls = False
231229
disallow_untyped_defs = False
232230

233231
[mypy-graphql_compiler.tests.schema_transformation_tests.test_merge_schemas.*]

0 commit comments

Comments
 (0)