From 4cec241c7e168135f3c2646aa67ca9f9e83cd042 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Tue, 11 Aug 2020 12:14:20 -0400 Subject: [PATCH 1/6] Improve type checking in query pagination and AST manipulation code. Includes type-level fixes for handling InlineFragmentNode AST values in query pagination code, a couple of new helper functions, and type hints for the AST manipulation module. --- graphql_compiler/ast_manipulation.py | 95 +++++++++++++++--- .../query_pagination/pagination_planning.py | 16 ++- .../query_pagination/query_parameterizer.py | 97 +++++++++++++------ 3 files changed, 163 insertions(+), 45 deletions(-) diff --git a/graphql_compiler/ast_manipulation.py b/graphql_compiler/ast_manipulation.py index 48d0956df..221f17f22 100644 --- a/graphql_compiler/ast_manipulation.py +++ b/graphql_compiler/ast_manipulation.py @@ -1,41 +1,53 @@ # Copyright 2019-present Kensho Technologies, LLC. +from typing import Any, Optional, Type, Union + from graphql.error import GraphQLSyntaxError from graphql.language.ast import ( DocumentNode, + FieldNode, InlineFragmentNode, ListTypeNode, + NamedTypeNode, + Node, NonNullTypeNode, OperationDefinitionNode, OperationType, + SelectionNode, + TypeNode, ) from graphql.language.parser import parse +from graphql.pyutils import FrozenList from .exceptions import GraphQLParsingError -def get_ast_field_name(ast): +def get_ast_field_name(ast: FieldNode) -> str: """Return the field name for the given AST node.""" return ast.name.value -def get_ast_field_name_or_none(ast): +def get_ast_field_name_or_none(ast: Union[FieldNode, InlineFragmentNode]) -> Optional[str]: """Return the field name for the AST node, or None if the AST is an InlineFragment.""" if isinstance(ast, InlineFragmentNode): return None return get_ast_field_name(ast) -def get_human_friendly_ast_field_name(ast): +def get_human_friendly_ast_field_name(ast: Node) -> str: """Return a human-friendly name for the AST node, suitable for error messages.""" if isinstance(ast, InlineFragmentNode): return "type coercion to {}".format(ast.type_condition) elif isinstance(ast, OperationDefinitionNode): return "{} operation definition".format(ast.operation) - - return get_ast_field_name(ast) + elif isinstance(ast, FieldNode): + return get_ast_field_name(ast) + else: + # Fall back to Node's __repr__() method. + # If we need more information for a specific type, we can add another branch in the if-elif. + return repr(ast) -def safe_parse_graphql(graphql_string): +def safe_parse_graphql(graphql_string: str) -> DocumentNode: """Return an AST representation of the given GraphQL input, reraising GraphQL library errors.""" try: ast = parse(graphql_string) @@ -45,7 +57,9 @@ def safe_parse_graphql(graphql_string): return ast -def get_only_query_definition(document_ast, desired_error_type): +def get_only_query_definition( + document_ast: DocumentNode, desired_error_type: Type[Exception] +) -> OperationDefinitionNode: """Assert that the Document AST contains only a single definition for a query, and return it.""" if not isinstance(document_ast, DocumentNode) or not document_ast.definitions: raise AssertionError( @@ -59,6 +73,12 @@ def get_only_query_definition(document_ast, desired_error_type): ) definition_ast = document_ast.definitions[0] + if not isinstance(definition_ast, OperationDefinitionNode): + raise desired_error_type( + f"Expected a query definition at the start of the GraphQL input, but found an " + f"unsupported and unrecognized definition instead: {definition_ast}" + ) + if definition_ast.operation != OperationType.QUERY: raise desired_error_type( "Expected a GraphQL document with a single query definition, but instead found a " @@ -70,9 +90,12 @@ def get_only_query_definition(document_ast, desired_error_type): return definition_ast -def get_only_selection_from_ast(ast, desired_error_type): +def get_only_selection_from_ast( + ast: Union[FieldNode, InlineFragmentNode, OperationDefinitionNode], + desired_error_type: Type[Exception], +) -> SelectionNode: """Return the selected sub-ast, ensuring that there is precisely one.""" - selections = [] if ast.selection_set is None else ast.selection_set.selections + selections = FrozenList([]) if ast.selection_set is None else ast.selection_set.selections if len(selections) != 1: ast_name = get_human_friendly_ast_field_name(ast) @@ -96,8 +119,10 @@ def get_only_selection_from_ast(ast, desired_error_type): return selections[0] -def get_ast_with_non_null_stripped(ast): +def get_ast_with_non_null_stripped(ast: TypeNode) -> Union[ListTypeNode, NamedTypeNode]: """Strip a NonNullType layer around the AST if there is one, return the underlying AST.""" + result: TypeNode + if isinstance(ast, NonNullTypeNode): stripped_ast = ast.type if isinstance(stripped_ast, NonNullTypeNode): @@ -105,13 +130,57 @@ def get_ast_with_non_null_stripped(ast): "NonNullType is unexpectedly found to wrap around another NonNullType in AST " "{}, which is not allowed.".format(ast) ) - return stripped_ast + result = stripped_ast else: - return ast + result = ast + + if not isinstance(result, (ListTypeNode, NamedTypeNode)): + raise AssertionError( + f"Expected the result to be either a ListTypeNode or a NamedTypeNode, but instead " + f"found: {result}" + ) + return result -def get_ast_with_non_null_and_list_stripped(ast): + +def get_ast_with_non_null_and_list_stripped(ast: TypeNode) -> NamedTypeNode: """Strip any NonNullType or List layers around the AST, return the underlying AST.""" while isinstance(ast, (NonNullTypeNode, ListTypeNode)): ast = ast.type + + if not isinstance(ast, NamedTypeNode): + raise AssertionError( + f"Expected the AST value to be a NamedTypeNode, but unexpectedly instead found: {ast}" + ) + + return ast + + +def assert_selection_is_a_field_node(ast: Any) -> FieldNode: + """Return the input value, if it is indeed a FieldNode, or raise AssertionError otherwise.""" + # N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, + # and a number of AST types. Since this function asserts the type of the input anyway, + # this is not a concern. + if not isinstance(ast, FieldNode): + raise AssertionError( + f"Expected AST to be a FieldNode, but instead found {type(ast).__name__}. " + f"This is a bug. Node value: {ast}" + ) + + return ast + + +def assert_selection_is_a_field_or_inline_fragment_node( + ast: Any, +) -> Union[FieldNode, InlineFragmentNode]: + """Return the input FieldNode or InlineFragmentNode, or raise AssertionError otherwise.""" + # N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, + # and a number of AST types. Since this function asserts the type of the input anyway, + # this is not a concern. + if not isinstance(ast, (FieldNode, InlineFragmentNode)): + raise AssertionError( + f"Expected AST to be a FieldNode or InlineFragmentNode, but instead " + f"found {type(ast).__name__}. This is a bug. Node value: {ast}" + ) + return ast diff --git a/graphql_compiler/query_pagination/pagination_planning.py b/graphql_compiler/query_pagination/pagination_planning.py index cd183aa47..0e5ff4b24 100644 --- a/graphql_compiler/query_pagination/pagination_planning.py +++ b/graphql_compiler/query_pagination/pagination_planning.py @@ -3,7 +3,12 @@ from dataclasses import dataclass, field from typing import Tuple -from ..ast_manipulation import get_only_query_definition, get_only_selection_from_ast +from ..ast_manipulation import ( + assert_selection_is_a_field_node, + get_ast_field_name, + get_only_query_definition, + get_only_selection_from_ast, +) from ..cost_estimation.analysis import QueryPlanningAnalysis from ..exceptions import GraphQLError from ..global_utils import PropertyPath @@ -122,8 +127,11 @@ def get_pagination_plan( # TODO(bojanserafimov): Make a better pagination plan. A non-root vertex might have a # higher pagination capacity than the root does. - root_node = get_only_selection_from_ast(definition_ast, GraphQLError).name.value - pagination_node = root_node + root_ast_node = assert_selection_is_a_field_node( + get_only_selection_from_ast(definition_ast, GraphQLError) + ) + root_vertex_name = get_ast_field_name(root_ast_node) + pagination_node = root_vertex_name # TODO(bojanserafimov): Remove pagination fields. The pagination planner is now smart enough # to pick the best field for pagination based on the query. This is not @@ -134,7 +142,7 @@ def get_pagination_plan( return PaginationPlan(tuple()), (PaginationFieldNotSpecified(pagination_node),) # Get the pagination capacity - vertex_path = (root_node,) + vertex_path = (root_vertex_name,) property_path = PropertyPath(vertex_path, pagination_field) capacity = query_analysis.pagination_capacities.get(property_path) # If the pagination capacity is None, then there must be no quantiles for this property. diff --git a/graphql_compiler/query_pagination/query_parameterizer.py b/graphql_compiler/query_pagination/query_parameterizer.py index 5865c7440..3690fd95e 100644 --- a/graphql_compiler/query_pagination/query_parameterizer.py +++ b/graphql_compiler/query_pagination/query_parameterizer.py @@ -1,7 +1,7 @@ # Copyright 2019-present Kensho Technologies, LLC. from copy import copy import logging -from typing import Any, Dict, List, Set, Tuple, cast +from typing import Any, Dict, List, Set, Tuple, TypeVar, Union, cast from graphql import print_ast from graphql.language.ast import ( @@ -19,7 +19,13 @@ ) from graphql.pyutils import FrozenList -from ..ast_manipulation import get_ast_field_name, get_only_query_definition +from ..ast_manipulation import ( + assert_selection_is_a_field_node, + assert_selection_is_a_field_or_inline_fragment_node, + get_ast_field_name, + get_only_query_definition, + get_only_selection_from_ast, +) from ..compiler.helpers import get_parameter_name from ..cost_estimation.analysis import QueryPlanningAnalysis from ..cost_estimation.int_value_conversion import convert_field_value_to_int @@ -113,14 +119,17 @@ def _are_filter_operations_equal_and_possible_to_eliminate( return False +FieldOrFragmentT = TypeVar("FieldOrFragmentT", FieldNode, InlineFragmentNode) + + def _add_pagination_filter_at_node( query_analysis: QueryPlanningAnalysis, node_vertex_path: VertexPath, - node_ast: SelectionNode, + node_ast: FieldOrFragmentT, pagination_field: str, directive_to_add: DirectiveNode, extended_parameters: Dict[str, Any], -) -> Tuple[SelectionNode, Dict[str, Any]]: +) -> Tuple[FieldOrFragmentT, Dict[str, Any]]: """Add the filter to the target field, returning a query and its new parameters. Args: @@ -139,11 +148,6 @@ def _add_pagination_filter_at_node( the same operation removed. new_parameters: The parameters to use with the new_ast """ - if not isinstance(node_ast, (FieldNode, InlineFragmentNode, OperationDefinitionNode)): - raise AssertionError( - f'Input AST is of type "{type(node_ast).__name__}", which should not be a selection.' - ) - if node_ast.selection_set is None: raise AssertionError( f"Vertex field AST has no selection set at path {node_vertex_path}. " @@ -159,15 +163,16 @@ def _add_pagination_filter_at_node( new_selections = [] found_field = False for selection_ast in node_ast.selection_set.selections: - new_selection_ast = selection_ast - field_name = get_ast_field_name(selection_ast) + field_ast = assert_selection_is_a_field_node(selection_ast) + new_field_ast = field_ast + field_name = get_ast_field_name(field_ast) if field_name == pagination_field: found_field = True - new_selection_ast = copy(selection_ast) - new_selection_ast.directives = copy(selection_ast.directives) + new_field_ast = copy(field_ast) + new_field_ast.directives = copy(field_ast.directives) new_directives: List[DirectiveNode] = [] - for directive in selection_ast.directives or []: + for directive in field_ast.directives or []: operation = _get_filter_node_operation(directive) if _are_filter_operations_equal_and_possible_to_eliminate( new_directive_operation, operation @@ -200,8 +205,8 @@ def _add_pagination_filter_at_node( else: new_directives.append(directive) new_directives.append(directive_to_add) - new_selection_ast.directives = FrozenList(new_directives) - new_selections.append(new_selection_ast) + new_field_ast.directives = FrozenList(new_directives) + new_selections.append(new_field_ast) # If field didn't exist, create it and add the new directive to it. if not found_field: @@ -214,15 +219,20 @@ def _add_pagination_filter_at_node( return new_ast, new_parameters +FieldOrFragmentOrDefinitionT = TypeVar( + "FieldOrFragmentOrDefinitionT", FieldNode, InlineFragmentNode, OperationDefinitionNode +) + + def _add_pagination_filter_recursively( query_analysis: QueryPlanningAnalysis, - node_ast: SelectionNode, + node_ast: FieldOrFragmentOrDefinitionT, full_query_path: VertexPath, query_path: VertexPath, pagination_field: str, directive_to_add: DirectiveNode, extended_parameters: Dict[str, Any], -) -> Tuple[SelectionNode, Dict[str, Any]]: +) -> Tuple[FieldOrFragmentOrDefinitionT, Dict[str, Any]]: """Add the filter to the target field, returning a query and its new parameters. Args: @@ -241,12 +251,16 @@ def _add_pagination_filter_recursively( the same operation removed. new_parameters: The parameters to use with the new_ast """ - if not isinstance(node_ast, (FieldNode, InlineFragmentNode, OperationDefinitionNode)): - raise AssertionError( - f'Input AST is of type "{type(node_ast).__name__}", which should not be a selection.' - ) - if len(query_path) == 0: + if isinstance(node_ast, OperationDefinitionNode): + # N.B.: We cannot use assert_selection_is_a_field_or_inline_fragment_node() here, since + # that function's return type annotation is a Union type that isn't compatible + # with the generic type signature of _add_pagination_filter_at_node(). + raise AssertionError( + f"Unexpectedly encountered an OperationDefinitionNode while expecting either a " + f"FieldNode or InlineFragmentNode. This is a bug. Node value: {node_ast}" + ) + return _add_pagination_filter_at_node( query_analysis, full_query_path, @@ -259,23 +273,50 @@ def _add_pagination_filter_recursively( if node_ast.selection_set is None: raise AssertionError(f"Invalid query path {query_path} {node_ast}.") + # Handle the case if a node contains only a type coercion. In that case, this node contains + # no fields, so we recurse into the type coercion and recursively apply pagination there. + if len(node_ast.selection_set.selections) == 1: + inner_selection_ast = get_only_selection_from_ast(node_ast, AssertionError) + if isinstance(inner_selection_ast, InlineFragmentNode): + # This node contains only a type coercion. Recurse to its contents and return early. + new_inner_selection_ast, new_parameters = _add_pagination_filter_recursively( + query_analysis, + inner_selection_ast, + full_query_path, + query_path, + pagination_field, + directive_to_add, + extended_parameters, + ) + + new_ast = copy(node_ast) + new_ast.selection_set = SelectionSetNode(selections=[new_inner_selection_ast]) + return new_ast, new_parameters + else: + # No type coercion here -- only field selections. Fall through to the code below. + pass + found_field = False new_selections = [] for selection_ast in node_ast.selection_set.selections: - new_selection_ast = selection_ast - field_name = get_ast_field_name(selection_ast) + # At this point, we know that all selections within this node are of FieldNode type, since + # we handled the InlineFragment case above. Assert that our selection is a FieldNode. + field_ast = assert_selection_is_a_field_node(selection_ast) + + new_field_ast = field_ast + field_name = get_ast_field_name(field_ast) if field_name == query_path[0]: found_field = True - new_selection_ast, new_parameters = _add_pagination_filter_recursively( + new_field_ast, new_parameters = _add_pagination_filter_recursively( query_analysis, - selection_ast, + field_ast, full_query_path, query_path[1:], pagination_field, directive_to_add, extended_parameters, ) - new_selections.append(new_selection_ast) + new_selections.append(new_field_ast) if not found_field: raise AssertionError(f"Invalid query path {query_path} {node_ast}.") From acc53daf8194f68a1f5f5a0d12c674e6873a9a60 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Tue, 11 Aug 2020 12:23:05 -0400 Subject: [PATCH 2/6] Tighten mypy.ini with help from typing-copilot. --- mypy.ini | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/mypy.ini b/mypy.ini index 458108cbf..67eb6dbc6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -14,11 +14,6 @@ warn_unused_ignores = True # First party per-module rule relaxations -[mypy-graphql_compiler.ast_manipulation.*] -check_untyped_defs = False -disallow_untyped_calls = False -disallow_untyped_defs = False - [mypy-graphql_compiler.compiler.blocks.*] disallow_incomplete_defs = False disallow_untyped_defs = False @@ -163,15 +158,9 @@ disallow_untyped_calls = False [mypy-graphql_compiler.query_formatting.match_formatting.*] disallow_untyped_calls = False -[mypy-graphql_compiler.query_pagination.pagination_planning.*] -disallow_untyped_calls = False - [mypy-graphql_compiler.query_pagination.parameter_generator.*] disallow_untyped_calls = False -[mypy-graphql_compiler.query_pagination.query_parameterizer.*] -disallow_untyped_calls = False - [mypy-graphql_compiler.schema_generation.graphql_schema.*] check_untyped_defs = False disallow_incomplete_defs = False From 14fc32c8910e3acac9628670bb9e3a02058df6d6 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Tue, 11 Aug 2020 16:37:20 -0400 Subject: [PATCH 3/6] Fix lint. --- graphql_compiler/query_pagination/query_parameterizer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/graphql_compiler/query_pagination/query_parameterizer.py b/graphql_compiler/query_pagination/query_parameterizer.py index 3690fd95e..be8ab692b 100644 --- a/graphql_compiler/query_pagination/query_parameterizer.py +++ b/graphql_compiler/query_pagination/query_parameterizer.py @@ -1,7 +1,7 @@ # Copyright 2019-present Kensho Technologies, LLC. from copy import copy import logging -from typing import Any, Dict, List, Set, Tuple, TypeVar, Union, cast +from typing import Any, Dict, List, Set, Tuple, TypeVar, cast from graphql import print_ast from graphql.language.ast import ( @@ -13,7 +13,6 @@ ListValueNode, NameNode, OperationDefinitionNode, - SelectionNode, SelectionSetNode, StringValueNode, ) @@ -21,7 +20,6 @@ from ..ast_manipulation import ( assert_selection_is_a_field_node, - assert_selection_is_a_field_or_inline_fragment_node, get_ast_field_name, get_only_query_definition, get_only_selection_from_ast, From ee7321674d11070595251e61a9c06926f0efb178 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Tue, 18 Aug 2020 16:49:04 -0400 Subject: [PATCH 4/6] Add explicit checked cast functions. --- graphql_compiler/ast_manipulation.py | 32 +------------------ graphql_compiler/global_utils.py | 29 ++++++++++++++++- .../query_pagination/pagination_planning.py | 9 +++--- .../query_pagination/query_parameterizer.py | 20 ++++++++---- graphql_compiler/tests/test_global_utils.py | 22 ++++++++++++- 5 files changed, 68 insertions(+), 44 deletions(-) diff --git a/graphql_compiler/ast_manipulation.py b/graphql_compiler/ast_manipulation.py index 221f17f22..ccc23f1e9 100644 --- a/graphql_compiler/ast_manipulation.py +++ b/graphql_compiler/ast_manipulation.py @@ -1,5 +1,5 @@ # Copyright 2019-present Kensho Technologies, LLC. -from typing import Any, Optional, Type, Union +from typing import Optional, Type, Union from graphql.error import GraphQLSyntaxError from graphql.language.ast import ( @@ -154,33 +154,3 @@ def get_ast_with_non_null_and_list_stripped(ast: TypeNode) -> NamedTypeNode: ) return ast - - -def assert_selection_is_a_field_node(ast: Any) -> FieldNode: - """Return the input value, if it is indeed a FieldNode, or raise AssertionError otherwise.""" - # N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, - # and a number of AST types. Since this function asserts the type of the input anyway, - # this is not a concern. - if not isinstance(ast, FieldNode): - raise AssertionError( - f"Expected AST to be a FieldNode, but instead found {type(ast).__name__}. " - f"This is a bug. Node value: {ast}" - ) - - return ast - - -def assert_selection_is_a_field_or_inline_fragment_node( - ast: Any, -) -> Union[FieldNode, InlineFragmentNode]: - """Return the input FieldNode or InlineFragmentNode, or raise AssertionError otherwise.""" - # N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, - # and a number of AST types. Since this function asserts the type of the input anyway, - # this is not a concern. - if not isinstance(ast, (FieldNode, InlineFragmentNode)): - raise AssertionError( - f"Expected AST to be a FieldNode or InlineFragmentNode, but instead " - f"found {type(ast).__name__}. This is a bug. Node value: {ast}" - ) - - return ast diff --git a/graphql_compiler/global_utils.py b/graphql_compiler/global_utils.py index 4b4c1ac75..6b519cb8b 100755 --- a/graphql_compiler/global_utils.py +++ b/graphql_compiler/global_utils.py @@ -1,6 +1,6 @@ # Copyright 2017-present Kensho Technologies, LLC. from dataclasses import dataclass -from typing import Any, Dict, NamedTuple, Set, Tuple, Type, TypeVar +from typing import Any, Dict, NamedTuple, Set, Tuple, Type, TypeVar, Union from graphql import DocumentNode, GraphQLList, GraphQLNamedType, GraphQLNonNull, GraphQLType from graphql.language.printer import print_ast @@ -109,3 +109,30 @@ def assert_set_equality(set1: Set[Any], set2: Set[Any]) -> None: if diff2: error_message_list.append(f"Keys in the second set but not the first: {diff2}.") raise AssertionError(" ".join(error_message_list)) + + +T_A = TypeVar("T_A") +T_B = TypeVar("T_B") + + +def checked_cast(target_type: Type[T_A], value: Any) -> T_A: + """Assert that the value is of the given type; checked version of typing.cast().""" + if not isinstance(value, target_type): + raise AssertionError( + f"Expected value {value} to be an instance of {target_type}, but that was unexpectedly " + f"not the case. Its current type is {type(value).__name__}." + ) + + return value + + +def checked_cast_to_union2( + target_types: Tuple[Type[T_A], Type[T_B]], value: Any, +) -> Union[T_A, T_B]: + if not isinstance(value, target_types): + raise AssertionError( + f"Expected value {value} to be an instance of one of {target_types}, but that " + f"was unexpectedly not the case. Its current type is {type(value).__name__}." + ) + + return value diff --git a/graphql_compiler/query_pagination/pagination_planning.py b/graphql_compiler/query_pagination/pagination_planning.py index 0e5ff4b24..260b92a66 100644 --- a/graphql_compiler/query_pagination/pagination_planning.py +++ b/graphql_compiler/query_pagination/pagination_planning.py @@ -3,15 +3,16 @@ from dataclasses import dataclass, field from typing import Tuple +from graphql import FieldNode + from ..ast_manipulation import ( - assert_selection_is_a_field_node, get_ast_field_name, get_only_query_definition, get_only_selection_from_ast, ) from ..cost_estimation.analysis import QueryPlanningAnalysis from ..exceptions import GraphQLError -from ..global_utils import PropertyPath +from ..global_utils import PropertyPath, checked_cast @dataclass @@ -127,8 +128,8 @@ def get_pagination_plan( # TODO(bojanserafimov): Make a better pagination plan. A non-root vertex might have a # higher pagination capacity than the root does. - root_ast_node = assert_selection_is_a_field_node( - get_only_selection_from_ast(definition_ast, GraphQLError) + root_ast_node = checked_cast( + FieldNode, get_only_selection_from_ast(definition_ast, GraphQLError) ) root_vertex_name = get_ast_field_name(root_ast_node) pagination_node = root_vertex_name diff --git a/graphql_compiler/query_pagination/query_parameterizer.py b/graphql_compiler/query_pagination/query_parameterizer.py index be8ab692b..f8a118c5f 100644 --- a/graphql_compiler/query_pagination/query_parameterizer.py +++ b/graphql_compiler/query_pagination/query_parameterizer.py @@ -19,7 +19,6 @@ from graphql.pyutils import FrozenList from ..ast_manipulation import ( - assert_selection_is_a_field_node, get_ast_field_name, get_only_query_definition, get_only_selection_from_ast, @@ -28,7 +27,12 @@ from ..cost_estimation.analysis import QueryPlanningAnalysis from ..cost_estimation.int_value_conversion import convert_field_value_to_int from ..exceptions import GraphQLError -from ..global_utils import ASTWithParameters, PropertyPath, VertexPath +from ..global_utils import ( + ASTWithParameters, + PropertyPath, + VertexPath, + checked_cast, +) from .pagination_planning import VertexPartitionPlan @@ -161,7 +165,7 @@ def _add_pagination_filter_at_node( new_selections = [] found_field = False for selection_ast in node_ast.selection_set.selections: - field_ast = assert_selection_is_a_field_node(selection_ast) + field_ast = checked_cast(FieldNode, selection_ast) new_field_ast = field_ast field_name = get_ast_field_name(field_ast) if field_name == pagination_field: @@ -251,9 +255,11 @@ def _add_pagination_filter_recursively( """ if len(query_path) == 0: if isinstance(node_ast, OperationDefinitionNode): - # N.B.: We cannot use assert_selection_is_a_field_or_inline_fragment_node() here, since - # that function's return type annotation is a Union type that isn't compatible - # with the generic type signature of _add_pagination_filter_at_node(). + # N.B.: We can't use checked_cast_to_union2() here, because we need to narrow a generic + # bound to another generic bound: FieldOrFragmentOrDefinitionT -> FieldOrFragmentT + # Generics in Python are erased at runtime, so we have to do this by "subtracting" + # out the OperationDefinitionNode, which is the difference between the two + # generic types. raise AssertionError( f"Unexpectedly encountered an OperationDefinitionNode while expecting either a " f"FieldNode or InlineFragmentNode. This is a bug. Node value: {node_ast}" @@ -299,7 +305,7 @@ def _add_pagination_filter_recursively( for selection_ast in node_ast.selection_set.selections: # At this point, we know that all selections within this node are of FieldNode type, since # we handled the InlineFragment case above. Assert that our selection is a FieldNode. - field_ast = assert_selection_is_a_field_node(selection_ast) + field_ast = checked_cast(FieldNode, selection_ast) new_field_ast = field_ast field_name = get_ast_field_name(field_ast) diff --git a/graphql_compiler/tests/test_global_utils.py b/graphql_compiler/tests/test_global_utils.py index 7e42fd2fc..c7ef05352 100644 --- a/graphql_compiler/tests/test_global_utils.py +++ b/graphql_compiler/tests/test_global_utils.py @@ -1,7 +1,7 @@ # Copyright 2020-present Kensho Technologies, LLC. import unittest -from ..global_utils import assert_set_equality +from ..global_utils import assert_set_equality, checked_cast, checked_cast_to_union2 class GlobalUtilTests(unittest.TestCase): @@ -24,3 +24,23 @@ def test_assert_equality(self) -> None: # Different types with self.assertRaises(AssertionError): assert_set_equality({"a"}, {1}) + + def test_checked_cast(self) -> None: + checked_cast(int, 123) + checked_cast(str, "foo") + + with self.assertRaises(AssertionError): + checked_cast(int, "foo") + + with self.assertRaises(AssertionError): + checked_cast(int, None) + + def test_checked_cast_to_union2(self) -> None: + checked_cast_to_union2((int, str), 123) + checked_cast_to_union2((int, str), "foo") + + with self.assertRaises(AssertionError): + checked_cast_to_union2((int, str), None) + + with self.assertRaises(AssertionError): + checked_cast_to_union2((bool, str), 123) From e7c7022a63b7611ef8b2b4028b8eb783fdac98d4 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Tue, 18 Aug 2020 16:53:33 -0400 Subject: [PATCH 5/6] Add docstring and note. --- graphql_compiler/global_utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/graphql_compiler/global_utils.py b/graphql_compiler/global_utils.py index 6b519cb8b..d022c79f8 100755 --- a/graphql_compiler/global_utils.py +++ b/graphql_compiler/global_utils.py @@ -129,6 +129,10 @@ def checked_cast(target_type: Type[T_A], value: Any) -> T_A: def checked_cast_to_union2( target_types: Tuple[Type[T_A], Type[T_B]], value: Any, ) -> Union[T_A, T_B]: + """Assert that the value is one of the two specified types.""" + # N.B.: If we ever need a union of more than two types, feel free to make functions like + # checked_cast_to_union3() or checked_cast_to_union4() as necessary, so long as you + # cover them with tests on par with the ones for this function. if not isinstance(value, target_types): raise AssertionError( f"Expected value {value} to be an instance of one of {target_types}, but that " From 2cf340aded5cd6303028b415f495858c063a2ce8 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Tue, 29 Sep 2020 10:29:29 -0400 Subject: [PATCH 6/6] Fix lint. --- graphql_compiler/global_utils.py | 3 ++- graphql_compiler/query_pagination/query_parameterizer.py | 7 +------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/graphql_compiler/global_utils.py b/graphql_compiler/global_utils.py index d022c79f8..24a30babd 100755 --- a/graphql_compiler/global_utils.py +++ b/graphql_compiler/global_utils.py @@ -127,7 +127,8 @@ def checked_cast(target_type: Type[T_A], value: Any) -> T_A: def checked_cast_to_union2( - target_types: Tuple[Type[T_A], Type[T_B]], value: Any, + target_types: Tuple[Type[T_A], Type[T_B]], + value: Any, ) -> Union[T_A, T_B]: """Assert that the value is one of the two specified types.""" # N.B.: If we ever need a union of more than two types, feel free to make functions like diff --git a/graphql_compiler/query_pagination/query_parameterizer.py b/graphql_compiler/query_pagination/query_parameterizer.py index cb4ee65fe..84d2a7cd4 100644 --- a/graphql_compiler/query_pagination/query_parameterizer.py +++ b/graphql_compiler/query_pagination/query_parameterizer.py @@ -27,12 +27,7 @@ from ..cost_estimation.analysis import QueryPlanningAnalysis from ..cost_estimation.int_value_conversion import convert_field_value_to_int from ..exceptions import GraphQLError -from ..global_utils import ( - ASTWithParameters, - PropertyPath, - VertexPath, - checked_cast, -) +from ..global_utils import ASTWithParameters, PropertyPath, VertexPath, checked_cast from .pagination_planning import VertexPartitionPlan