Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,12 @@
from ..ast_manipulation import get_only_query_definition
from ..exceptions import GraphQLValidationError
from ..schema import FilterDirective, OutputDirective
from .utils import get_query_runtime_arguments


SubQueryPlan = namedtuple(
'SubQueryPlan', (
'plan_id', # int, unique identifier for this sub-plan
'query_ast', # Document, representing a piece of the overall query with directives added
'schema_id', # str, identifying the schema that this query piece targets
'parent_query_plan', # SubQueryPlan, the query that the current query depends on
Expand All @@ -26,6 +28,9 @@
OutputJoinDescriptor = namedtuple(
'OutputJoinDescriptor', (
'output_names', # Tuple[str, str], (parent output name, child output name)
'child_query_plan', # SubQueryPlan, the sub-plan node for which the join happens
# between it and its parent sub-plan

# May be expanded to have more attributes, e.g. is_optional, describing how the join
# should be made
)
Expand Down Expand Up @@ -67,13 +72,14 @@ def make_query_plan(root_sub_query_node, intermediate_output_names):
output_join_descriptors = []

root_sub_query_plan = SubQueryPlan(
plan_id=0,
query_ast=root_sub_query_node.query_ast,
schema_id=root_sub_query_node.schema_id,
parent_query_plan=None,
child_query_plans=[],
)

_make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors)
_make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors, 1)

return QueryPlanDescriptor(
root_sub_query_plan=root_sub_query_plan,
Expand All @@ -82,7 +88,8 @@ def make_query_plan(root_sub_query_node, intermediate_output_names):
)


def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors):
def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors,
next_plan_id):
"""Recursively copy the structure of sub_query_node onto sub_query_plan.

For each child connection contained in sub_query_node, create a new SubQueryPlan for
Expand All @@ -96,7 +103,7 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr
modified
output_join_descriptors: List[OutputJoinDescriptor], describing which outputs should be
joined and how

next_plan_id: int, the next available plan ID to use. IDs at and above this number are free.
"""
# Iterate through child connections of query node
for child_query_connection in sub_query_node.child_query_connections:
Expand All @@ -123,24 +130,27 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr

# Create new SubQueryPlan for child
child_sub_query_plan = SubQueryPlan(
plan_id=next_plan_id,
query_ast=new_child_query_ast,
schema_id=child_sub_query_node.schema_id,
parent_query_plan=sub_query_plan,
child_query_plans=[],
)
next_plan_id += 1

# Add new SubQueryPlan to parent's child list
sub_query_plan.child_query_plans.append(child_sub_query_plan)

# Add information about this edge
new_output_join_descriptor = OutputJoinDescriptor(
output_names=(parent_out_name, child_out_name),
child_query_plan=child_sub_query_plan,
)
output_join_descriptors.append(new_output_join_descriptor)

# Recursively repeat on child SubQueryPlans
_make_query_plan_recursive(
child_sub_query_node, child_sub_query_plan, output_join_descriptors
child_sub_query_node, child_sub_query_plan, output_join_descriptors, next_plan_id
)


Expand Down Expand Up @@ -253,15 +263,25 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4):
line_separation = u'\n' + u' ' * indentation_depth * depth
query_plan_strings.append(line_separation)

query_str = u'Execute in schema named "{}":\n'.format(query_plan.schema_id)
query_str = u'Execute subplan ID {} in schema named "{}":\n'.format(
query_plan.plan_id, query_plan.schema_id)
query_str += print_ast(query_plan.query_ast)
query_str = query_str.replace(u'\n', line_separation)
query_plan_strings.append(query_str)

query_plan_strings.append(u'\n\nJoin together outputs as follows: ')
query_plan_strings.append(str(query_plan_descriptor.output_join_descriptors))
query_plan_strings.append(str([
' '.join([
str(descriptor.output_names),
'between subplan IDs',
str([
descriptor.child_query_plan.parent_query_plan.plan_id,
descriptor.child_query_plan.plan_id
])])
for descriptor in query_plan_descriptor.output_join_descriptors
]))
query_plan_strings.append(u'\n\nRemove the following outputs at the end: ')
query_plan_strings.append(str(query_plan_descriptor.intermediate_output_names) + u'\n')
query_plan_strings.append(str(set(query_plan_descriptor.intermediate_output_names)) + u'\n')

return ''.join(query_plan_strings)

Expand All @@ -276,3 +296,164 @@ def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth):
)
return plan_and_depth_in_dfs_order
return _get_plan_and_depth_in_dfs_order_helper(query_plan, 0)


def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query_args):
    """Execute the given query plan and return the produced results.

    Args:
        schema_id_to_execution_func: Dict[str, Callable], mapping each schema id to a
            function that accepts a GraphQL query string and a dict of runtime arguments,
            executes the query against that schema, and returns the result rows as an
            iterable of dicts (signature inferred from usage below -- confirm with callers)
        query_plan_descriptor: QueryPlanDescriptor, describing the sub-plans to execute,
            the outputs to join on, and the intermediate outputs to drop at the end
        query_args: Dict[str, Any], the user-supplied runtime arguments for the query

    Returns:
        List[Dict], the joined result rows with intermediate stitching outputs removed
    """
    result_components_by_plan_id = {}

    # Group stitching output names by the plan id of the parent sub-plan, so that after
    # running each subquery we know which of its output columns feed its children's
    # runtime arguments.
    stitching_output_names_by_parent_plan_id = dict()
    for join_descriptor in query_plan_descriptor.output_join_descriptors:
        parent_plan_id = join_descriptor.child_query_plan.parent_query_plan.plan_id
        stitching_output_names_by_parent_plan_id.setdefault(parent_plan_id, []).append(
            join_descriptor.output_names)

    # Start with the user-supplied arguments; stitching arguments are added as the
    # subqueries execute and produce output values.
    full_query_args = dict(query_args)

    # NOTE: make_query_plan() is what fixes the order between subqueries; once a query plan
    # is made, it is executed from the root onward. This executor is a simple synchronous
    # DFS; other executor styles (e.g. BFS, async/parallel across children) are possible.
    plan_and_depth = _get_plan_and_depth_in_dfs_order(query_plan_descriptor.root_sub_query_plan)
    for query_plan, _ in plan_and_depth:
        plan_id = query_plan.plan_id
        schema_id = query_plan.schema_id

        subquery_graphql = print_ast(query_plan.query_ast)

        # HACK(predrag): Add proper error checking for missing arguments here.
        # HACK(predrag): Don't bother running queries if the previous query's stitching outputs
        #                returned no values to pass to the next query.
        subquery_args = {
            argument_name: full_query_args[argument_name]
            for argument_name in get_query_runtime_arguments(query_plan.query_ast)
        }

        # Run the query and save the results.
        execution_func = schema_id_to_execution_func[schema_id]
        subquery_result = execution_func(subquery_graphql, subquery_args)
        result_components_by_plan_id[plan_id] = subquery_result

        # Capture and record any values that will be used for stitching by other subqueries.
        # The .get() call handles query plans with no children: they have no extra output
        # values for their children, on account of having no children.
        child_extra_output_names = {
            output_name
            for output_name, _ in stitching_output_names_by_parent_plan_id.get(plan_id, [])
        }
        child_extra_output_values = {
            # Make sure we deduplicate the values -- there's no point in running subqueries
            # with duplicated runtime argument values.
            output_name: set()
            for output_name in child_extra_output_names
        }
        for subquery_row in subquery_result:
            for output_name in child_extra_output_names:
                # We intentionally discard None values -- None is never a foreign key value.
                # This is standard in all relational systems as well.
                output_value = subquery_row.get(output_name, None)
                if output_value is not None:
                    child_extra_output_values[output_name].add(output_value)

        # TODO(predrag): Use the "merge_disjoint_dicts" function here,
        #                there should never be any overlap here.
        new_query_args = {
            # Argument values cannot be sets, so we turn the sets back into lists.
            output_argument_name: list(child_extra_output_values[output_argument_name])
            for output_argument_name in child_extra_output_names
        }
        full_query_args.update(new_query_args)

    # Join the per-subquery result components into complete rows, then drop any outputs
    # that exist purely for stitching purposes.
    join_indexes_by_plan_id = _make_join_indexes(
        query_plan_descriptor, result_components_by_plan_id)

    joined_results = _join_results(
        result_components_by_plan_id, join_indexes_by_plan_id,
        result_components_by_plan_id[query_plan_descriptor.root_sub_query_plan.plan_id],
        query_plan_descriptor.output_join_descriptors)

    return _drop_intermediate_outputs(
        query_plan_descriptor.intermediate_output_names, joined_results)


def _make_join_indexes(query_plan_descriptor, result_components_by_plan_id):
    """Build a join index over each child sub-plan's result rows.

    Args:
        query_plan_descriptor: QueryPlanDescriptor, whose output_join_descriptors specify
            which child result sets require indexing and on which output column
        result_components_by_plan_id: Dict[int, List[Dict]], result rows keyed by plan id

    Returns:
        Dict[int, Dict], mapping each child plan id to a join index over that child's
        rows, keyed on the child's join output column values
    """
    indexes_by_child_plan_id = {}

    for descriptor in query_plan_descriptor.output_join_descriptors:
        child_plan_id = descriptor.child_query_plan.plan_id
        child_output_name = descriptor.output_names[1]

        # Each child sub-plan joins to its parent exactly once, so encountering the same
        # child plan id twice indicates an internal inconsistency in the query plan.
        if child_plan_id in indexes_by_child_plan_id:
            raise AssertionError('Unreachable code reached: {} {} {}'
                                 .format(child_plan_id, indexes_by_child_plan_id,
                                         query_plan_descriptor.output_join_descriptors))

        child_rows = result_components_by_plan_id[child_plan_id]
        indexes_by_child_plan_id[child_plan_id] = _make_join_index_for_output(
            child_rows, child_output_name)

    return indexes_by_child_plan_id


def _make_join_index_for_output(results, join_output_name):
"""Return a dict of each value of the join column to a list of row indexes where it appears."""
print('making join index on column ', join_output_name)
print(results)

join_index = {}
for row_index, row in enumerate(results):
join_value = row[join_output_name]
join_index.setdefault(join_value, []).append(row_index)

return join_index


def _join_results(result_components_by_plan_id, join_indexes_by_plan_id,
current_results, join_descriptors):
"""Return the merged results across all subplans using the calculated join indexes."""
if len(join_descriptors) == 0:
# No further joining to be done!
return current_results

next_results = []

next_join_descriptor = join_descriptors[0]
remaining_join_descriptors = join_descriptors[1:]

join_plan_id = next_join_descriptor.child_query_plan.plan_id
join_index = join_indexes_by_plan_id[join_plan_id]
joining_results = result_components_by_plan_id[join_plan_id]
join_from_key, join_to_key = next_join_descriptor.output_names

for current_row in current_results:
join_value = current_row[join_from_key]

# To get inner join semantics, we don't output results that don't have matches.
# When we add support for stitching across @optional edges, we'll need to update this
# code to also output results even when the join index doesn't contain matches.
for join_matched_index in join_index.get(join_value, []):
joining_row = joining_results[join_matched_index]
next_results.append(dict(current_row, **joining_row))

return _join_results(result_components_by_plan_id, join_indexes_by_plan_id,
next_results, remaining_join_descriptors)


def _drop_intermediate_outputs(columns_to_drop, results):
"""Return the provided results with the specified column names dropped."""
processed_results = []

for row in results:
processed_results.append({
key: value
for key, value in row.items()
if key not in columns_to_drop
})

return processed_results
7 changes: 5 additions & 2 deletions graphql_compiler/schema_transformation/split_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, query_ast):
# List[SubQueryNode], the queries that depend on the current query


def split_query(query_ast, merged_schema_descriptor):
def split_query(query_ast, merged_schema_descriptor, strict=True):
"""Split input query AST into a tree of SubQueryNodes targeting each individual schema.

Property fields used in the stitch will be added if not already present. @output directives
Expand All @@ -62,6 +62,9 @@ def split_query(query_ast, merged_schema_descriptor):
schema: GraphQLSchema representing the merged schema
type_name_to_schema_id: Dict[str, str], mapping type names to
the id of the schema it came from
strict: bool, if set to True then limits query splitting to queries that are guaranteed
to be safely splittable. If False, then some queries may be permitted to be split
even though they are illegal. Use with caution.

Returns:
Tuple[SubQueryNode, frozenset[str]]. The first element is the root of the tree of
Expand All @@ -77,7 +80,7 @@ def split_query(query_ast, merged_schema_descriptor):
- SchemaStructureError if the input merged_schema_descriptor appears to be invalid
or inconsistent
"""
check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast)
check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast, strict=strict)

# If schema directives are correctly represented in the schema object, type_info is all
# that's needed to detect and address stitching fields. However, GraphQL currently ignores
Expand Down
53 changes: 50 additions & 3 deletions graphql_compiler/schema_transformation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@
import six

from ..ast_manipulation import get_ast_with_non_null_and_list_stripped
from ..compiler.helpers import (
get_parameter_name, get_uniquely_named_objects_by_name, is_runtime_parameter
)
from ..exceptions import GraphQLError, GraphQLValidationError
from ..schema import FilterDirective, OptionalDirective, OutputDirective
from ..schema import FilterDirective, OptionalDirective, OutputDirective, RecurseDirective


class SchemaTransformError(GraphQLError):
Expand Down Expand Up @@ -397,8 +400,16 @@ class CheckQueryIsValidToSplitVisitor(Visitor):
OptionalDirective.name,
))

def __init__(self, strict):
"""Initialize the visitor with the appropriate strictness setting."""
super(CheckQueryIsValidToSplitVisitor, self).__init__()
self.strict = strict

def enter_Directive(self, node, *args):
"""Check that the directive is supported."""
if not self.strict:
return

if node.name.value not in self.supported_directives:
raise GraphQLValidationError(
u'Directive "{}" is not yet supported, only "{}" are currently '
Expand Down Expand Up @@ -443,7 +454,7 @@ def enter_SelectionSet(self, node, *args):
seen_vertex_field = True


def check_query_is_valid_to_split(schema, query_ast):
def check_query_is_valid_to_split(schema, query_ast, strict=True):
"""Check the query is valid for splitting.

In particular, ensure that the query validates against the schema, does not contain
Expand All @@ -453,6 +464,9 @@ def check_query_is_valid_to_split(schema, query_ast):
Args:
schema: GraphQLSchema object
query_ast: Document
strict: bool, if set to True then limits query splitting to queries that are guaranteed
to be safely splittable. If False, then some queries may be permitted to be split
even though they are illegal. Use with caution.

Raises:
GraphQLValidationError if the query doesn't validate against the schema, contains
Expand All @@ -466,5 +480,38 @@ def check_query_is_valid_to_split(schema, query_ast):
u'AST does not validate: {}'.format(built_in_validation_errors)
)
# Check no bad directives and fields are in order
visitor = CheckQueryIsValidToSplitVisitor()
visitor = CheckQueryIsValidToSplitVisitor(strict)
visit(query_ast, visitor)


class QueryRuntimeArgumentsVisitor(Visitor):
    """Visitor that collects runtime argument names from @filter directives.

    After visiting a query AST, the names of all runtime parameters referenced by @filter
    directive values are available in the visitor's `runtime_arguments` attribute.
    """

    def __init__(self):
        """Initialize the visitor with an empty set of collected argument names."""
        super(QueryRuntimeArgumentsVisitor, self).__init__()
        # Set[str], runtime parameter names seen so far; populated during visitation.
        self.runtime_arguments = set()

    def enter_Directive(self, node, *args):
        """Record any runtime parameter names used by a @filter directive's value."""
        # Only @filter directives can reference runtime parameters; ignore all others.
        if node.name.value != FilterDirective.name:
            return

        directive_arguments = get_uniquely_named_objects_by_name(node.arguments)
        # NOTE(review): assumes the @filter's "value" argument is a GraphQL list literal
        # whose elements are string values -- confirm against the @filter directive schema.
        entry_names = [
            list_element.value
            for list_element in directive_arguments['value'].value.values
        ]

        # Only entries recognized as runtime parameters are collected; other entry
        # kinds are intentionally skipped.
        self.runtime_arguments.update(
            get_parameter_name(name)
            for name in entry_names
            if is_runtime_parameter(name)
        )


def get_query_runtime_arguments(query_ast):
    """Return a set containing the names of the runtime arguments required by the query.

    Args:
        query_ast: Document, the GraphQL query AST to inspect

    Returns:
        Set[str], the runtime parameter names referenced by @filter directives in the query
    """
    collector = QueryRuntimeArgumentsVisitor()
    visit(query_ast, collector)
    return collector.runtime_arguments
Loading