Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 82 additions & 13 deletions graphql_compiler/ast_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,53 @@
# Copyright 2019-present Kensho Technologies, LLC.
from typing import Any, Optional, Type, Union

from graphql.error import GraphQLSyntaxError
from graphql.language.ast import (
DocumentNode,
FieldNode,
InlineFragmentNode,
ListTypeNode,
NamedTypeNode,
Node,
NonNullTypeNode,
OperationDefinitionNode,
OperationType,
SelectionNode,
TypeNode,
)
from graphql.language.parser import parse
from graphql.pyutils import FrozenList

from .exceptions import GraphQLParsingError


def get_ast_field_name(ast: FieldNode) -> str:
    """Return the field name for the given AST node.

    Args:
        ast: a field AST node; its `name.value` attribute holds the field name.

    Returns:
        the string stored at `ast.name.value`
    """
    return ast.name.value


def get_ast_field_name_or_none(ast: Union[FieldNode, InlineFragmentNode]) -> Optional[str]:
    """Return the field name for the AST node, or None if the AST is an InlineFragment.

    Args:
        ast: either a field node or an inline fragment node

    Returns:
        None when `ast` is an InlineFragmentNode; otherwise the result of
        get_ast_field_name() for the node.
    """
    # Inline fragments are checked first: the field-name lookup is only attempted for other nodes.
    if isinstance(ast, InlineFragmentNode):
        return None
    return get_ast_field_name(ast)


def get_human_friendly_ast_field_name(ast: Node) -> str:
    """Return a human-friendly name for the AST node, suitable for error messages.

    Args:
        ast: any AST node

    Returns:
        - "type coercion to <type_condition>" for an InlineFragmentNode
        - "<operation> operation definition" for an OperationDefinitionNode
        - the field name for a FieldNode
        - repr(ast) for any other kind of node
    """
    if isinstance(ast, InlineFragmentNode):
        return "type coercion to {}".format(ast.type_condition)
    elif isinstance(ast, OperationDefinitionNode):
        return "{} operation definition".format(ast.operation)
    elif isinstance(ast, FieldNode):
        return get_ast_field_name(ast)
    else:
        # Fall back to Node's __repr__() method.
        # If we need more information for a specific type, we can add another branch in the if-elif.
        return repr(ast)


def safe_parse_graphql(graphql_string):
def safe_parse_graphql(graphql_string: str) -> DocumentNode:
"""Return an AST representation of the given GraphQL input, reraising GraphQL library errors."""
try:
ast = parse(graphql_string)
Expand All @@ -45,7 +57,9 @@ def safe_parse_graphql(graphql_string):
return ast


def get_only_query_definition(document_ast, desired_error_type):
def get_only_query_definition(
document_ast: DocumentNode, desired_error_type: Type[Exception]
) -> OperationDefinitionNode:
"""Assert that the Document AST contains only a single definition for a query, and return it."""
if not isinstance(document_ast, DocumentNode) or not document_ast.definitions:
raise AssertionError(
Expand All @@ -59,6 +73,12 @@ def get_only_query_definition(document_ast, desired_error_type):
)

definition_ast = document_ast.definitions[0]
if not isinstance(definition_ast, OperationDefinitionNode):
raise desired_error_type(
f"Expected a query definition at the start of the GraphQL input, but found an "
f"unsupported and unrecognized definition instead: {definition_ast}"
)

if definition_ast.operation != OperationType.QUERY:
raise desired_error_type(
"Expected a GraphQL document with a single query definition, but instead found a "
Expand All @@ -70,9 +90,12 @@ def get_only_query_definition(document_ast, desired_error_type):
return definition_ast


def get_only_selection_from_ast(ast, desired_error_type):
def get_only_selection_from_ast(
ast: Union[FieldNode, InlineFragmentNode, OperationDefinitionNode],
desired_error_type: Type[Exception],
) -> SelectionNode:
"""Return the selected sub-ast, ensuring that there is precisely one."""
selections = [] if ast.selection_set is None else ast.selection_set.selections
selections = FrozenList([]) if ast.selection_set is None else ast.selection_set.selections

if len(selections) != 1:
ast_name = get_human_friendly_ast_field_name(ast)
Expand All @@ -96,22 +119,68 @@ def get_only_selection_from_ast(ast, desired_error_type):
return selections[0]


def get_ast_with_non_null_stripped(ast: TypeNode) -> Union[ListTypeNode, NamedTypeNode]:
    """Strip a NonNullType layer around the AST if there is one, return the underlying AST.

    Args:
        ast: a type AST node, possibly wrapped in a single NonNullTypeNode

    Returns:
        the node inside the NonNullTypeNode wrapper if there was one, otherwise `ast` itself

    Raises:
        AssertionError: if a NonNullTypeNode directly wraps another NonNullTypeNode, or if the
            resulting node is neither a ListTypeNode nor a NamedTypeNode
    """
    result: TypeNode

    if isinstance(ast, NonNullTypeNode):
        stripped_ast = ast.type
        if isinstance(stripped_ast, NonNullTypeNode):
            raise AssertionError(
                "NonNullType is unexpectedly found to wrap around another NonNullType in AST "
                "{}, which is not allowed.".format(ast)
            )
        result = stripped_ast
    else:
        result = ast

    # Validate at runtime so the narrower return type holds for every input.
    if not isinstance(result, (ListTypeNode, NamedTypeNode)):
        raise AssertionError(
            f"Expected the result to be either a ListTypeNode or a NamedTypeNode, but instead "
            f"found: {result}"
        )

    return result

def get_ast_with_non_null_and_list_stripped(ast: TypeNode) -> NamedTypeNode:
    """Strip any NonNullType or List layers around the AST, return the underlying AST.

    Args:
        ast: a type AST node, wrapped in any number of NonNullTypeNode / ListTypeNode layers

    Returns:
        the innermost node once all NonNullTypeNode and ListTypeNode layers are removed

    Raises:
        AssertionError: if the innermost node is not a NamedTypeNode
    """
    # Peel wrapper layers one at a time until neither kind of wrapper remains.
    while isinstance(ast, (NonNullTypeNode, ListTypeNode)):
        ast = ast.type

    if not isinstance(ast, NamedTypeNode):
        raise AssertionError(
            f"Expected the AST value to be a NamedTypeNode, but unexpectedly instead found: {ast}"
        )

    return ast


def assert_selection_is_a_field_node(ast: Any) -> FieldNode:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's strange for an assert function to return, though I get why it has to.

In relation to mypy's cast function, this is just a safe_cast. If you don't find mypy's use of the word cast repulsive, this would be a good name for this function. We can even make it generic: safe_cast(value: Any, type: Type[T]) -> T.

I personally don't like mypy's naming of cast since there's no change of shape involved (like in metal casting). It should have been called unsafe_assume_type. But the function exists, there's nothing I can do about it, and calling our function safe_cast makes it a discoverable alternative to cast.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the idea — perhaps checked_cast as an alternative name, to make the answer to "how is this safer" a bit more explicit?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checked_cast() doesn't work with Union because Union is not a Type, so Type[T] doesn't match.

For bounded-length tuples, this is achievable:

def checked_cast_to_union2(target_types: Tuple[Type[T_A], Type[T_B]], value: Any) -> Union[T_A, T_B]:
    if not isinstance(value, target_types):
        raise AssertionError()

    return value

I don't know of a way to make target_types: Tuple[Type, ...] work and connect the types it contains to the types inside the Union result.

Given that we're somewhat unlikely to have unions of more than like 3 or so things, I'm fine with adding checked_cast_to_union2() and in the future potentially adding a checked_cast_to_union3() function that takes a 3-tuple as necessary.

What do you think?

"""Return the input value, if it is indeed a FieldNode, or raise AssertionError otherwise."""
# N.B.: Using Any as the input type, since the inputs here are a variety of generics, unions,
# and a number of AST types. Since this function asserts the type of the input anyway,
# this is not a concern.
if not isinstance(ast, FieldNode):
raise AssertionError(
f"Expected AST to be a FieldNode, but instead found {type(ast).__name__}. "
f"This is a bug. Node value: {ast}"
)

return ast


def assert_selection_is_a_field_or_inline_fragment_node(
    ast: Any,
) -> Union[FieldNode, InlineFragmentNode]:
    """Hand back the input if it is a FieldNode or InlineFragmentNode, else raise AssertionError.

    Args:
        ast: the value to check; typed as Any because callers pass a variety of generics,
            unions, and AST types, and the type is verified here at runtime anyway.

    Returns:
        the same `ast` object, unchanged

    Raises:
        AssertionError: if `ast` is neither a FieldNode nor an InlineFragmentNode
    """
    if isinstance(ast, (FieldNode, InlineFragmentNode)):
        return ast

    raise AssertionError(
        f"Expected AST to be a FieldNode or InlineFragmentNode, but instead "
        f"found {type(ast).__name__}. This is a bug. Node value: {ast}"
    )
16 changes: 12 additions & 4 deletions graphql_compiler/query_pagination/pagination_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..ast_manipulation import get_only_query_definition, get_only_selection_from_ast
from ..ast_manipulation import (
assert_selection_is_a_field_node,
get_ast_field_name,
get_only_query_definition,
get_only_selection_from_ast,
)
from ..cost_estimation.analysis import QueryPlanningAnalysis
from ..exceptions import GraphQLError
from ..global_utils import PropertyPath
Expand Down Expand Up @@ -122,8 +127,11 @@ def get_pagination_plan(

# TODO(bojanserafimov): Make a better pagination plan. A non-root vertex might have a
# higher pagination capacity than the root does.
root_node = get_only_selection_from_ast(definition_ast, GraphQLError).name.value
pagination_node = root_node
root_ast_node = assert_selection_is_a_field_node(
get_only_selection_from_ast(definition_ast, GraphQLError)
)
root_vertex_name = get_ast_field_name(root_ast_node)
pagination_node = root_vertex_name

# TODO(bojanserafimov): Remove pagination fields. The pagination planner is now smart enough
# to pick the best field for pagination based on the query. This is not
Expand All @@ -134,7 +142,7 @@ def get_pagination_plan(
return PaginationPlan(tuple()), (PaginationFieldNotSpecified(pagination_node),)

# Get the pagination capacity
vertex_path = (root_node,)
vertex_path = (root_vertex_name,)
property_path = PropertyPath(vertex_path, pagination_field)
capacity = query_analysis.pagination_capacities.get(property_path)
# If the pagination capacity is None, then there must be no quantiles for this property.
Expand Down
Loading