Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Commit b276810

Browse files
Add edges defined by a composite join clause (#998)
* Add composite join descriptor * Add TODOs * Construct general on clause * Deal with _came_from columns * Add errors * Update _find_used_columns * Add test with traversal * Lint * Lint * Add todos * Lint * Add error messages * Add error message * Address coment * Update graphql_compiler/compiler/emit_sql.py Co-authored-by: Predrag Gruevski <[email protected]> * Address comments * Lint Co-authored-by: Predrag Gruevski <[email protected]>
1 parent c645238 commit b276810

File tree

8 files changed

+266
-70
lines changed

8 files changed

+266
-70
lines changed

graphql_compiler/compiler/emit_sql.py

Lines changed: 71 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# Copyright 2018-present Kensho Technologies, LLC.
22
"""Transform a SqlNode tree into an executable SQLAlchemy query."""
33
from dataclasses import dataclass
4-
from typing import Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union
4+
from typing import AbstractSet, Dict, Iterator, List, NamedTuple, Optional, Set, Tuple, Union
55

66
import six
77
import sqlalchemy
@@ -21,7 +21,12 @@
2121
from . import blocks
2222
from ..global_utils import VertexPath
2323
from ..schema import COUNT_META_FIELD_NAME
24-
from ..schema.schema_info import DirectJoinDescriptor, SQLAlchemySchemaInfo
24+
from ..schema.schema_info import (
25+
CompositeJoinDescriptor,
26+
DirectJoinDescriptor,
27+
JoinDescriptor,
28+
SQLAlchemySchemaInfo,
29+
)
2530
from .compiler_entities import BasicBlock
2631
from .compiler_frontend import IrAndMetadata
2732
from .expressions import ContextField, Expression
@@ -143,21 +148,30 @@ def _find_used_columns(
143148
)
144149
vertex_field_name = f"{edge_direction}_{edge_name}"
145150
edge = sql_schema_info.join_descriptors[location_info.type.name][vertex_field_name]
146-
used_columns.setdefault(get_vertex_path(location), set()).add(edge.from_column)
147-
used_columns.setdefault(get_vertex_path(child_location), set()).add(edge.to_column)
151+
if isinstance(edge, DirectJoinDescriptor):
152+
columns_at_location = {edge.from_column}
153+
columns_at_child = {edge.to_column}
154+
elif isinstance(edge, CompositeJoinDescriptor):
155+
columns_at_location = {column_pair[0] for column_pair in edge.column_pairs}
156+
columns_at_child = {column_pair[1] for column_pair in edge.column_pairs}
157+
else:
158+
raise AssertionError(f"Unknown join descriptor type {edge}: {type(edge)}")
159+
160+
used_columns.setdefault(get_vertex_path(location), set()).update(columns_at_location)
161+
used_columns.setdefault(get_vertex_path(child_location), set()).update(columns_at_child)
148162

149163
# Check if the edge is recursive
150164
child_location_info = ir.query_metadata_table.get_location_info(child_location)
151165
if child_location_info.recursive_scopes_depth > location_info.recursive_scopes_depth:
152166
# The primary key may be used if the recursive cte base semijoins to
153167
# the pre-recurse cte by primary key.
154168
alias = sql_schema_info.vertex_name_to_table[location_info.type.name].alias()
155-
primary_key_name = _get_primary_key_name(alias, location_info.type.name, "@recurse")
156-
used_columns.setdefault(get_vertex_path(location), set()).add(primary_key_name)
169+
primary_keys = {column.name for column in alias.primary_key}
170+
used_columns.setdefault(get_vertex_path(location), set()).update(primary_keys)
157171

158172
# The from_column is used at the destination as well, inside the recursive step
159-
used_columns.setdefault(get_vertex_path(child_location), set()).add(
160-
edge.from_column
173+
used_columns.setdefault(get_vertex_path(child_location), set()).update(
174+
columns_at_location
161175
)
162176

163177
# Find outputs used
@@ -780,7 +794,9 @@ def __init__(self, sql_schema_info: SQLAlchemySchemaInfo, ir: IrAndMetadata):
780794
# Move to the beginning location of the query.
781795
self._relocate(ir.query_metadata_table.root_location)
782796

783-
# Mapping aliases to the column used to join into them.
797+
# Mapping aliases to one of the column used to join into them. We use this column
798+
# to check for LEFT JOIN misses, since it helps us distinguish actuall NULL values
799+
# from values that are NULL because of a LEFT JOIN miss.
784800
self._came_from: Dict[Union[Alias, ColumnRouter], Column] = {}
785801

786802
self._recurse_needs_cte: bool = False
@@ -840,9 +856,8 @@ def _relocate(self, new_location: BaseLocation):
840856
self._current_alias, self._current_location, output_fields
841857
)
842858

843-
# TODO merge from_column and to_column into a joindescriptor
844859
def _join_to_parent_location(
845-
self, parent_alias: Alias, from_column: str, to_column: str, optional: bool
860+
self, parent_alias: Alias, join_descriptor: JoinDescriptor, optional: bool
846861
):
847862
"""Join the current location to the parent location using the column names specified."""
848863
if self._current_alias is None:
@@ -851,7 +866,25 @@ def _join_to_parent_location(
851866
f"during fold {self}."
852867
)
853868

854-
self._came_from[self._current_alias] = self._current_alias.c[to_column]
869+
# construct on clause for join
870+
if isinstance(join_descriptor, DirectJoinDescriptor):
871+
matching_column_pairs: AbstractSet[Tuple[str, str]] = {
872+
(join_descriptor.from_column, join_descriptor.to_column),
873+
}
874+
elif isinstance(join_descriptor, CompositeJoinDescriptor):
875+
matching_column_pairs = join_descriptor.column_pairs
876+
else:
877+
raise AssertionError(
878+
f"Unknown join descriptor type {join_descriptor}: {type(join_descriptor)}"
879+
)
880+
881+
if not matching_column_pairs:
882+
raise AssertionError(
883+
f"Invalid join descriptor {join_descriptor}, produced no matching column pairs."
884+
)
885+
886+
_, non_null_column = sorted(matching_column_pairs)[0]
887+
self._came_from[self._current_alias] = self._current_alias.c[non_null_column]
855888

856889
if self._is_in_optional_scope() and not optional:
857890
# For mandatory edges in optional scope, we emit LEFT OUTER JOIN and enforce the
@@ -879,10 +912,17 @@ def _join_to_parent_location(
879912
)
880913
)
881914

915+
on_clause = sqlalchemy.and_(
916+
*(
917+
parent_alias.c[from_column] == self._current_alias.c[to_column]
918+
for from_column, to_column in sorted(matching_column_pairs)
919+
)
920+
)
921+
882922
# Join to where we came from.
883923
self._from_clause = self._from_clause.join(
884924
self._current_alias,
885-
onclause=(parent_alias.c[from_column] == self._current_alias.c[to_column]),
925+
onclause=on_clause,
886926
isouter=self._is_in_optional_scope(),
887927
)
888928

@@ -932,11 +972,14 @@ def traverse(self, vertex_field: str, optional: bool) -> None:
932972
"Attempting to traverse inside a fold while the _current_location was not a "
933973
f"FoldScopeLocation. _current_location was set to {self._current_location}."
934974
)
975+
if not isinstance(edge, DirectJoinDescriptor):
976+
raise NotImplementedError(
977+
f"Edge {vertex_field} is backed by a CompositeJoinDescriptor, "
978+
"so it can't be used inside a @fold scope."
979+
)
935980
self._current_fold.add_traversal(edge, previous_alias, self._current_alias)
936981
else:
937-
self._join_to_parent_location(
938-
previous_alias, edge.from_column, edge.to_column, optional
939-
)
982+
self._join_to_parent_location(previous_alias, edge, optional)
940983

941984
def _wrap_into_cte(self) -> None:
942985
"""Wrap the current query into a cte."""
@@ -1017,6 +1060,11 @@ def recurse(self, vertex_field: str, depth: int) -> None:
10171060
)
10181061

10191062
edge = self._sql_schema_info.join_descriptors[self._current_classname][vertex_field]
1063+
if not isinstance(edge, DirectJoinDescriptor):
1064+
raise NotImplementedError(
1065+
f"Edge {vertex_field} requires a JOIN across a composite key, which is currently "
1066+
f"not supported for use with @recurse."
1067+
)
10201068
primary_key = self._get_current_primary_key_name("@recurse")
10211069

10221070
# Wrap the query so far into a CTE if it would speed up the recursive query.
@@ -1074,8 +1122,8 @@ def recurse(self, vertex_field: str, depth: int) -> None:
10741122
.where(base.c[CTE_DEPTH_NAME] < literal_depth)
10751123
)
10761124

1077-
# TODO(bojanserafimov): This creates an unused alias if there's no tags or outputs so far
1078-
self._join_to_parent_location(previous_alias, primary_key, CTE_KEY_NAME, False)
1125+
join_descriptor = DirectJoinDescriptor(primary_key, CTE_KEY_NAME)
1126+
self._join_to_parent_location(previous_alias, join_descriptor, False)
10791127

10801128
def start_global_operations(self) -> None:
10811129
"""Execute a GlobalOperationsStart block."""
@@ -1131,6 +1179,11 @@ def fold(self, fold_scope_location: FoldScopeLocation) -> None:
11311179
join_descriptor = self._sql_schema_info.join_descriptors[self._current_classname][
11321180
full_edge_name
11331181
]
1182+
if not isinstance(join_descriptor, DirectJoinDescriptor):
1183+
raise NotImplementedError(
1184+
f"Edge {full_edge_name} requires a JOIN across a composite key, which is currently "
1185+
"not supported for use with @fold."
1186+
)
11341187

11351188
# 3. Initialize fold object.
11361189
self._current_fold = FoldSubqueryBuilder(

graphql_compiler/compiler/ir_lowering_sql/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from .. import blocks, expressions
55
from ...compiler.compiler_frontend import IrAndMetadata
6+
from ...schema.schema_info import CompositeJoinDescriptor, DirectJoinDescriptor
67
from ..helpers import FoldScopeLocation, get_edge_direction_and_name
78
from ..ir_lowering_common import common
89

@@ -40,9 +41,14 @@ def _find_non_null_columns(schema_info, query_metadata_table):
4041
vertex_field_name = "{}_{}".format(edge_direction, edge_name)
4142
edge = schema_info.join_descriptors[location_info.type.name][vertex_field_name]
4243

43-
# The value of the column used to join to this table is an indicator of whether
44+
# The value of any column used to join to this table is an indicator of whether
4445
# the left join was a hit or a miss.
45-
non_null_column[child_location.query_path] = edge.to_column
46+
if isinstance(edge, DirectJoinDescriptor):
47+
non_null_column[child_location.query_path] = edge.to_column
48+
elif isinstance(edge, CompositeJoinDescriptor):
49+
non_null_column[child_location.query_path] = sorted(edge.column_pairs)[0][1]
50+
else:
51+
raise AssertionError(f"Unknown join descriptor type {edge}: {type(edge)}")
4652

4753
return non_null_column
4854

graphql_compiler/schema/schema_info.py

Lines changed: 47 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from dataclasses import dataclass, field
55
from enum import Enum, Flag, auto, unique
66
from functools import partial
7-
from typing import Dict, Mapping, Optional, Sequence
7+
from typing import AbstractSet, Dict, Mapping, Optional, Sequence, Tuple, Union
88

99
from graphql.type import GraphQLSchema
1010
from graphql.type.definition import GraphQLInterfaceType, GraphQLObjectType
@@ -20,20 +20,46 @@
2020
from ..schema_generation.schema_graph import SchemaGraph
2121

2222

23-
# Describes the intent to join two tables using the specified columns.
24-
#
25-
# The resulting join expression could be something like:
26-
# JOIN origin_table.from_column = destination_table.to_column
27-
#
28-
# The type of join (inner vs left, etc.) is not specified.
29-
# The tables are not specified.
30-
DirectJoinDescriptor = namedtuple(
31-
"DirectJoinDescriptor",
32-
(
33-
"from_column", # The column in the source table we intend to join on.
34-
"to_column", # The column in the destination table we intend to join on.
35-
),
36-
)
23+
@dataclass
24+
class DirectJoinDescriptor:
25+
"""Describes the ability to join two tables using the specified columns.
26+
27+
The resulting join expression could be something like:
28+
JOIN origin_table.from_column = destination_table.to_column
29+
30+
The type of join (inner vs left, etc.) is not specified.
31+
The tables are not specified.
32+
"""
33+
34+
from_column: str # The column in the source table we intend to join on.
35+
to_column: str # The column in the destination table we intend to join on.
36+
37+
38+
@dataclass
39+
class CompositeJoinDescriptor:
40+
"""Describes the ability to join two tables with a composite relationship.
41+
42+
The resulting join expression could be something like:
43+
JOIN
44+
origin_table.from_column_1 == destination_table.to_column_1 AND
45+
origin_table.from_column_2 == destination_table.to_column_2 AND
46+
origin_table.from_column_3 == destination_table.to_column_3
47+
48+
The type of join (inner vs left, etc.) is not specified.
49+
The tables are not specified.
50+
"""
51+
52+
# (from_column, to_column) pairs, where from_column is on the origin table
53+
# and to_column is on the destination table of the join.
54+
column_pairs: AbstractSet[Tuple[str, str]]
55+
56+
def __post_init__(self) -> None:
57+
"""Validate fields."""
58+
if not self.column_pairs:
59+
raise AssertionError("The column_pairs field is expected to be non-empty.")
60+
61+
62+
JoinDescriptor = Union[DirectJoinDescriptor, CompositeJoinDescriptor]
3763

3864

3965
@dataclass
@@ -143,17 +169,17 @@ class SQLSchemaInfo(BackendSpecificSchemaInfo):
143169

144170
vertex_name_to_table: Dict[str, sqlalchemy.Table]
145171
# dict mapping every GraphQL object type or interface type name in the schema to
146-
# dict mapping every vertex field name at that type to a DirectJoinDescriptor.
172+
# dict mapping every vertex field name at that type to a JoinDescriptor.
147173
# The tables the join is to be performed on are not specified.
148174
# They are inferred from the schema and the tables dictionary.
149-
join_descriptors: Dict[str, Dict[str, DirectJoinDescriptor]]
175+
join_descriptors: Dict[str, Dict[str, JoinDescriptor]]
150176

151177

152178
def _create_sql_schema_info(
153179
dialect: Dialect,
154180
schema: GraphQLSchema,
155181
vertex_name_to_table: Dict[str, sqlalchemy.Table],
156-
join_descriptors: Dict[str, Dict[str, DirectJoinDescriptor]],
182+
join_descriptors: Dict[str, Dict[str, JoinDescriptor]],
157183
type_equivalence_hints: Optional[Dict[str, str]] = None,
158184
) -> SQLSchemaInfo:
159185
"""Create a SQLSchemaInfo object for a database using a flavor of SQL."""
@@ -239,7 +265,7 @@ def _create_sql_schema_info(
239265
# All tables are expected to have primary keys.
240266
"vertex_name_to_table",
241267
# dict mapping every graphql object type or interface type name in the schema to:
242-
# dict mapping every vertex field name at that type to a DirectJoinDescriptor. The
268+
# dict mapping every vertex field name at that type to a JoinDescriptor. The
243269
# tables the join is to be performed on are not specified. They are inferred from
244270
# the schema and the tables dictionary.
245271
"join_descriptors",
@@ -252,7 +278,7 @@ def make_sqlalchemy_schema_info(
252278
type_equivalence_hints: TypeEquivalenceHintsType,
253279
dialect: Dialect,
254280
vertex_name_to_table: Dict[str, sqlalchemy.Table],
255-
join_descriptors: Dict[str, Dict[str, DirectJoinDescriptor]],
281+
join_descriptors: Dict[str, Dict[str, JoinDescriptor]],
256282
validate: bool = True,
257283
) -> SQLAlchemySchemaInfo:
258284
"""Make a SQLAlchemySchemaInfo if the input provided is valid.
@@ -282,7 +308,7 @@ def make_sqlalchemy_schema_info(
282308
schema to a SQLAlchemy table
283309
join_descriptors: dict mapping GraphQL object and interface type names in the schema to:
284310
dict mapping every vertex field name at that type to a
285-
DirectJoinDescriptor. The tables on which the join is to be performed
311+
JoinDescriptor. The tables on which the join is to be performed
286312
are not specified. They are inferred from the schema and the tables
287313
dictionary.
288314
validate: whether to validate that the given inputs are valid for creation of

graphql_compiler/schema_generation/sqlalchemy/edge_descriptors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def generate_direct_edge_descriptors_from_foreign_keys(
115115
else:
116116
number_of_composite_foreign_keys += 1
117117

118+
# TODO(bojanserafimov): Infer CompositeJoinDescriptor objects
118119
if number_of_composite_foreign_keys:
119120
warnings.warn(
120121
"Ignored {} edges implied by composite foreign keys. We currently do not "

graphql_compiler/tests/test_compiler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11154,3 +11154,41 @@ def test_recursive_field_type_is_subtype_of_parent_field(self) -> None:
1115411154
expected_cypher,
1115511155
expected_sql,
1115611156
)
11157+
11158+
def test_animal_born_at_traversal(self) -> None:
11159+
"""Ensure that sql composite key traversals work."""
11160+
test_data = test_input_data.animal_born_at_traversal()
11161+
11162+
expected_match = SKIP_TEST
11163+
expected_gremlin = SKIP_TEST
11164+
expected_mssql = """
11165+
SELECT
11166+
[Animal_1].name AS animal_name,
11167+
[BirthEvent_1].name AS birth_event_name
11168+
FROM
11169+
db_1.schema_1.[Animal] AS [Animal_1]
11170+
JOIN db_1.schema_1.[BirthEvent] AS [BirthEvent_1]
11171+
ON [Animal_1].birth_date = [BirthEvent_1].event_date
11172+
AND [Animal_1].birth_uuid = [BirthEvent_1].uuid
11173+
"""
11174+
expected_postgresql = """
11175+
SELECT
11176+
"Animal_1".name AS animal_name,
11177+
"BirthEvent_1".name AS birth_event_name
11178+
FROM
11179+
schema_1."Animal" AS "Animal_1"
11180+
JOIN schema_1."BirthEvent" AS "BirthEvent_1"
11181+
ON "Animal_1".birth_date = "BirthEvent_1".event_date
11182+
AND "Animal_1".birth_uuid = "BirthEvent_1".uuid
11183+
"""
11184+
expected_cypher = SKIP_TEST
11185+
11186+
check_test_data(
11187+
self,
11188+
test_data,
11189+
expected_match,
11190+
expected_gremlin,
11191+
expected_mssql,
11192+
expected_cypher,
11193+
expected_postgresql,
11194+
)

0 commit comments

Comments
 (0)