-
Notifications
You must be signed in to change notification settings - Fork 50
Improve type checking in query pagination and AST manipulation code. #898
base: main
Are you sure you want to change the base?
Changes from 3 commits
4cec241
acc53da
14fc32c
4c9909c
ee73216
e7c7022
21f3742
2cf340a
259d287
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,41 +1,53 @@ | ||
# Copyright 2019-present Kensho Technologies, LLC. | ||
from typing import Any, Optional, Type, Union | ||
|
||
from graphql.error import GraphQLSyntaxError | ||
from graphql.language.ast import ( | ||
DocumentNode, | ||
FieldNode, | ||
InlineFragmentNode, | ||
ListTypeNode, | ||
NamedTypeNode, | ||
Node, | ||
NonNullTypeNode, | ||
OperationDefinitionNode, | ||
OperationType, | ||
SelectionNode, | ||
TypeNode, | ||
) | ||
from graphql.language.parser import parse | ||
from graphql.pyutils import FrozenList | ||
|
||
from .exceptions import GraphQLParsingError | ||
|
||
|
||
def get_ast_field_name(ast): | ||
def get_ast_field_name(ast: FieldNode) -> str: | ||
"""Return the field name for the given AST node.""" | ||
return ast.name.value | ||
|
||
|
||
def get_ast_field_name_or_none(ast): | ||
def get_ast_field_name_or_none(ast: Union[FieldNode, InlineFragmentNode]) -> Optional[str]: | ||
"""Return the field name for the AST node, or None if the AST is an InlineFragment.""" | ||
if isinstance(ast, InlineFragmentNode): | ||
return None | ||
return get_ast_field_name(ast) | ||
|
||
|
||
def get_human_friendly_ast_field_name(ast): | ||
def get_human_friendly_ast_field_name(ast: Node) -> str: | ||
"""Return a human-friendly name for the AST node, suitable for error messages.""" | ||
if isinstance(ast, InlineFragmentNode): | ||
return "type coercion to {}".format(ast.type_condition) | ||
elif isinstance(ast, OperationDefinitionNode): | ||
return "{} operation definition".format(ast.operation) | ||
|
||
return get_ast_field_name(ast) | ||
elif isinstance(ast, FieldNode): | ||
return get_ast_field_name(ast) | ||
else: | ||
# Fall back to Node's __repr__() method. | ||
# If we need more information for a specific type, we can add another branch in the if-elif. | ||
return repr(ast) | ||
|
||
|
||
def safe_parse_graphql(graphql_string): | ||
def safe_parse_graphql(graphql_string: str) -> DocumentNode: | ||
"""Return an AST representation of the given GraphQL input, reraising GraphQL library errors.""" | ||
try: | ||
ast = parse(graphql_string) | ||
|
@@ -45,7 +57,9 @@ def safe_parse_graphql(graphql_string): | |
return ast | ||
|
||
|
||
def get_only_query_definition(document_ast, desired_error_type): | ||
def get_only_query_definition( | ||
document_ast: DocumentNode, desired_error_type: Type[Exception] | ||
) -> OperationDefinitionNode: | ||
"""Assert that the Document AST contains only a single definition for a query, and return it.""" | ||
if not isinstance(document_ast, DocumentNode) or not document_ast.definitions: | ||
raise AssertionError( | ||
|
@@ -59,6 +73,12 @@ def get_only_query_definition(document_ast, desired_error_type): | |
) | ||
|
||
definition_ast = document_ast.definitions[0] | ||
if not isinstance(definition_ast, OperationDefinitionNode): | ||
raise desired_error_type( | ||
f"Expected a query definition at the start of the GraphQL input, but found an " | ||
f"unsupported and unrecognized definition instead: {definition_ast}" | ||
) | ||
|
||
if definition_ast.operation != OperationType.QUERY: | ||
raise desired_error_type( | ||
"Expected a GraphQL document with a single query definition, but instead found a " | ||
|
@@ -70,9 +90,12 @@ def get_only_query_definition(document_ast, desired_error_type): | |
return definition_ast | ||
|
||
|
||
def get_only_selection_from_ast(ast, desired_error_type): | ||
def get_only_selection_from_ast( | ||
ast: Union[FieldNode, InlineFragmentNode, OperationDefinitionNode], | ||
desired_error_type: Type[Exception], | ||
) -> SelectionNode: | ||
"""Return the selected sub-ast, ensuring that there is precisely one.""" | ||
selections = [] if ast.selection_set is None else ast.selection_set.selections | ||
selections = FrozenList([]) if ast.selection_set is None else ast.selection_set.selections | ||
|
||
if len(selections) != 1: | ||
ast_name = get_human_friendly_ast_field_name(ast) | ||
|
@@ -96,22 +119,68 @@ def get_only_selection_from_ast(ast, desired_error_type): | |
return selections[0] | ||
|
||
|
||
def get_ast_with_non_null_stripped(ast): | ||
def get_ast_with_non_null_stripped(ast: TypeNode) -> Union[ListTypeNode, NamedTypeNode]: | ||
"""Strip a NonNullType layer around the AST if there is one, return the underlying AST.""" | ||
result: TypeNode | ||
|
||
if isinstance(ast, NonNullTypeNode): | ||
stripped_ast = ast.type | ||
if isinstance(stripped_ast, NonNullTypeNode): | ||
raise AssertionError( | ||
"NonNullType is unexpectedly found to wrap around another NonNullType in AST " | ||
"{}, which is not allowed.".format(ast) | ||
) | ||
return stripped_ast | ||
result = stripped_ast | ||
else: | ||
return ast | ||
result = ast | ||
|
||
if not isinstance(result, (ListTypeNode, NamedTypeNode)): | ||
raise AssertionError( | ||
f"Expected the result to be either a ListTypeNode or a NamedTypeNode, but instead " | ||
f"found: {result}" | ||
) | ||
|
||
return result | ||
|
||
def get_ast_with_non_null_and_list_stripped(ast): | ||
|
||
def get_ast_with_non_null_and_list_stripped(ast: TypeNode) -> NamedTypeNode: | ||
"""Strip any NonNullType or List layers around the AST, return the underlying AST.""" | ||
while isinstance(ast, (NonNullTypeNode, ListTypeNode)): | ||
ast = ast.type | ||
|
||
if not isinstance(ast, NamedTypeNode): | ||
raise AssertionError( | ||
f"Expected the AST value to be a NamedTypeNode, but unexpectedly instead found: {ast}" | ||
) | ||
|
||
return ast | ||
|
||
|
||
def assert_selection_is_a_field_node(ast: Any) -> FieldNode: | ||
|
||
"""Return the input value, if it is indeed a FieldNode, or raise AssertionError otherwise.""" | ||
# N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, | ||
# and a number of AST types. Since this function asserts the type of the input anyway, | ||
# this is not a concern. | ||
if not isinstance(ast, FieldNode): | ||
raise AssertionError( | ||
f"Expected AST to be a FieldNode, but instead found {type(ast).__name__}. " | ||
f"This is a bug. Node value: {ast}" | ||
) | ||
|
||
return ast | ||
|
||
|
||
def assert_selection_is_a_field_or_inline_fragment_node( | ||
ast: Any, | ||
) -> Union[FieldNode, InlineFragmentNode]: | ||
"""Return the input FieldNode or InlineFragmentNode, or raise AssertionError otherwise.""" | ||
# N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions, | ||
# and a number of AST types. Since this function asserts the type of the input anyway, | ||
# this is not a concern. | ||
if not isinstance(ast, (FieldNode, InlineFragmentNode)): | ||
raise AssertionError( | ||
f"Expected AST to be a FieldNode or InlineFragmentNode, but instead " | ||
f"found {type(ast).__name__}. This is a bug. Node value: {ast}" | ||
) | ||
|
||
return ast |
Uh oh!
There was an error while loading. Please reload this page.