Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 51 additions & 12 deletions graphql_compiler/ast_manipulation.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,50 @@
# Copyright 2019-present Kensho Technologies, LLC.
from typing import Optional, Type, Union

from graphql.error import GraphQLSyntaxError
from graphql.language.ast import (
DocumentNode,
FieldNode,
InlineFragmentNode,
ListTypeNode,
NamedTypeNode,
Node,
NonNullTypeNode,
OperationDefinitionNode,
OperationType,
SelectionNode,
TypeNode,
)
from graphql.language.parser import parse
from graphql.pyutils import FrozenList

from .exceptions import GraphQLParsingError


def get_ast_field_name(ast):
def get_ast_field_name(ast: FieldNode) -> str:
"""Return the field name for the given AST node."""
return ast.name.value


def get_ast_field_name_or_none(ast):
def get_ast_field_name_or_none(ast: Union[FieldNode, InlineFragmentNode]) -> Optional[str]:
"""Return the field name for the AST node, or None if the AST is an InlineFragment."""
if isinstance(ast, InlineFragmentNode):
return None
return get_ast_field_name(ast)


def get_human_friendly_ast_field_name(ast):
def get_human_friendly_ast_field_name(ast: Node) -> str:
"""Return a human-friendly name for the AST node, suitable for error messages."""
if isinstance(ast, InlineFragmentNode):
return "type coercion to {}".format(ast.type_condition)
elif isinstance(ast, OperationDefinitionNode):
return "{} operation definition".format(ast.operation)

return get_ast_field_name(ast)
elif isinstance(ast, FieldNode):
return get_ast_field_name(ast)
else:
# Fall back to Node's __repr__() method.
# If we need more information for a specific type, we can add another branch in the if-elif.
return repr(ast)


def safe_parse_graphql(graphql_string: str) -> DocumentNode:
Expand All @@ -45,7 +57,9 @@ def safe_parse_graphql(graphql_string: str) -> DocumentNode:
return ast


def get_only_query_definition(document_ast, desired_error_type):
def get_only_query_definition(
document_ast: DocumentNode, desired_error_type: Type[Exception]
) -> OperationDefinitionNode:
"""Assert that the Document AST contains only a single definition for a query, and return it."""
if not isinstance(document_ast, DocumentNode) or not document_ast.definitions:
raise AssertionError(
Expand All @@ -59,6 +73,12 @@ def get_only_query_definition(document_ast, desired_error_type):
)

definition_ast = document_ast.definitions[0]
if not isinstance(definition_ast, OperationDefinitionNode):
raise desired_error_type(
f"Expected a query definition at the start of the GraphQL input, but found an "
f"unsupported and unrecognized definition instead: {definition_ast}"
)

if definition_ast.operation != OperationType.QUERY:
raise desired_error_type(
"Expected a GraphQL document with a single query definition, but instead found a "
Expand All @@ -70,9 +90,12 @@ def get_only_query_definition(document_ast, desired_error_type):
return definition_ast


def get_only_selection_from_ast(ast, desired_error_type):
def get_only_selection_from_ast(
ast: Union[FieldNode, InlineFragmentNode, OperationDefinitionNode],
desired_error_type: Type[Exception],
) -> SelectionNode:
"""Return the selected sub-ast, ensuring that there is precisely one."""
selections = [] if ast.selection_set is None else ast.selection_set.selections
selections = FrozenList([]) if ast.selection_set is None else ast.selection_set.selections

if len(selections) != 1:
ast_name = get_human_friendly_ast_field_name(ast)
Expand All @@ -96,22 +119,38 @@ def get_only_selection_from_ast(ast, desired_error_type):
return selections[0]


def get_ast_with_non_null_stripped(ast):
def get_ast_with_non_null_stripped(ast: TypeNode) -> Union[ListTypeNode, NamedTypeNode]:
"""Strip a NonNullType layer around the AST if there is one, return the underlying AST."""
result: TypeNode

if isinstance(ast, NonNullTypeNode):
stripped_ast = ast.type
if isinstance(stripped_ast, NonNullTypeNode):
raise AssertionError(
"NonNullType is unexpectedly found to wrap around another NonNullType in AST "
"{}, which is not allowed.".format(ast)
)
return stripped_ast
result = stripped_ast
else:
return ast
result = ast

if not isinstance(result, (ListTypeNode, NamedTypeNode)):
raise AssertionError(
f"Expected the result to be either a ListTypeNode or a NamedTypeNode, but instead "
f"found: {result}"
)

return result


def get_ast_with_non_null_and_list_stripped(ast):
def get_ast_with_non_null_and_list_stripped(ast: TypeNode) -> NamedTypeNode:
"""Strip any NonNullType or List layers around the AST, return the underlying AST."""
while isinstance(ast, (NonNullTypeNode, ListTypeNode)):
ast = ast.type

if not isinstance(ast, NamedTypeNode):
raise AssertionError(
f"Expected the AST value to be a NamedTypeNode, but unexpectedly instead found: {ast}"
)

return ast
34 changes: 33 additions & 1 deletion graphql_compiler/global_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright 2017-present Kensho Technologies, LLC.
from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Set, Tuple, Type, TypeVar
from typing import Any, Dict, NamedTuple, Set, Tuple, Type, TypeVar, Union

from graphql import DocumentNode, GraphQLList, GraphQLNamedType, GraphQLNonNull, GraphQLType
from graphql.language.printer import print_ast
Expand Down Expand Up @@ -109,3 +109,35 @@ def assert_set_equality(set1: Set[Any], set2: Set[Any]) -> None:
if diff2:
error_message_list.append(f"Keys in the second set but not the first: {diff2}.")
raise AssertionError(" ".join(error_message_list))


T_A = TypeVar("T_A")
T_B = TypeVar("T_B")


def checked_cast(target_type: Type[T_A], value: Any) -> T_A:
"""Assert that the value is of the given type; checked version of typing.cast()."""
if not isinstance(value, target_type):
raise AssertionError(
f"Expected value {value} to be an instance of {target_type}, but that was unexpectedly "
f"not the case. Its current type is {type(value).__name__}."
)

return value


def checked_cast_to_union2(
target_types: Tuple[Type[T_A], Type[T_B]],
value: Any,
) -> Union[T_A, T_B]:
"""Assert that the value is one of the two specified types."""
# N.B.: If we ever need a union of more than two types, feel free to make functions like
# checked_cast_to_union3() or checked_cast_to_union4() as necessary, so long as you
# cover them with tests on par with the ones for this function.
if not isinstance(value, target_types):
raise AssertionError(
f"Expected value {value} to be an instance of one of {target_types}, but that "
f"was unexpectedly not the case. Its current type is {type(value).__name__}."
)

return value
19 changes: 14 additions & 5 deletions graphql_compiler/query_pagination/pagination_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,16 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..ast_manipulation import get_only_query_definition, get_only_selection_from_ast
from graphql import FieldNode

from ..ast_manipulation import (
get_ast_field_name,
get_only_query_definition,
get_only_selection_from_ast,
)
from ..cost_estimation.analysis import QueryPlanningAnalysis
from ..exceptions import GraphQLError
from ..global_utils import PropertyPath
from ..global_utils import PropertyPath, checked_cast


@dataclass
Expand Down Expand Up @@ -122,8 +128,11 @@ def get_pagination_plan(

# TODO(bojanserafimov): Make a better pagination plan. A non-root vertex might have a
# higher pagination capacity than the root does.
root_node = get_only_selection_from_ast(definition_ast, GraphQLError).name.value
pagination_node = root_node
root_ast_node = checked_cast(
FieldNode, get_only_selection_from_ast(definition_ast, GraphQLError)
)
root_vertex_name = get_ast_field_name(root_ast_node)
pagination_node = root_vertex_name

# TODO(bojanserafimov): Remove pagination fields. The pagination planner is now smart enough
# to pick the best field for pagination based on the query. This is not
Expand All @@ -134,7 +143,7 @@ def get_pagination_plan(
return PaginationPlan(tuple()), (PaginationFieldNotSpecified(pagination_node),)

# Get the pagination capacity
vertex_path = (root_node,)
vertex_path = (root_vertex_name,)
property_path = PropertyPath(vertex_path, pagination_field)
capacity = query_analysis.pagination_capacities.get(property_path)
# If the pagination capacity is None, then there must be no quantiles for this property.
Expand Down
Loading