Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Commit c346068

Browse files
Add composite edge descriptor (#999)
* Add composite edge descriptor * Generalize some functions * Lint * Write error messages * Add test * Fix some docs
1 parent b276810 commit c346068

File tree

5 files changed

+118
-43
lines changed

5 files changed

+118
-43
lines changed

graphql_compiler/schema/schema_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..schema_generation.schema_graph import SchemaGraph
2121

2222

23-
@dataclass
23+
@dataclass(frozen=True)
2424
class DirectJoinDescriptor:
2525
"""Describes the ability to join two tables using the specified columns.
2626
@@ -35,7 +35,7 @@ class DirectJoinDescriptor:
3535
to_column: str # The column in the destination table we intend to join on.
3636

3737

38-
@dataclass
38+
@dataclass(frozen=True)
3939
class CompositeJoinDescriptor:
4040
"""Describes the ability to join two tables with a composite relationship.
4141

graphql_compiler/schema_generation/sqlalchemy/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
def get_sqlalchemy_schema_info(
9-
vertex_name_to_table, direct_edges, dialect, class_to_field_type_overrides=None
9+
vertex_name_to_table, edges, dialect, class_to_field_type_overrides=None
1010
):
1111
"""Return a SQLAlchemySchemaInfo from the metadata.
1212
@@ -31,14 +31,14 @@ def get_sqlalchemy_schema_info(
3131
The fields will have the same name as the underlying columns and
3232
columns with unsupported types, (SQL types with no matching GraphQL
3333
type), will be ignored.
34-
direct_edges: dict, str-> DirectEdgeDescriptor. The traversal of a direct
34+
edges: dict, str-> EdgeDescriptor. The traversal of an edge
3535
edge gets compiled to a SQL join in graphql_to_sql(). Therefore, each
36-
DirectEdgeDescriptor not only specifies the source and destination GraphQL
36+
EdgeDescriptor not only specifies the source and destination GraphQL
3737
objects, but also which columns to use to use when generating a SQL join
3838
between the underlying source and destination tables. The names of the edges
3939
are the keys in the dictionary and the edges will be rendered as vertex fields
4040
named out_<edgeName> and in_<edgeName> in the source and destination GraphQL
41-
objects respectively. The direct edge names must not conflict with the GraphQL
41+
objects respectively. The edge names must not conflict with the GraphQL
4242
object names.
4343
dialect: sqlalchemy.engine.interfaces.Dialect, specifying the dialect we are compiling to
4444
(e.g. sqlalchemy.dialects.mssql.dialect()).
@@ -50,15 +50,15 @@ def get_sqlalchemy_schema_info(
5050
Returns:
5151
SQLAlchemySchemaInfo containing the full information needed to compile SQL queries.
5252
"""
53-
schema_graph = get_sqlalchemy_schema_graph(vertex_name_to_table, direct_edges)
53+
schema_graph = get_sqlalchemy_schema_graph(vertex_name_to_table, edges)
5454

5555
graphql_schema, type_equivalence_hints = get_graphql_schema_from_schema_graph(
5656
schema_graph,
5757
class_to_field_type_overrides=class_to_field_type_overrides,
5858
hidden_classes=set(),
5959
)
6060

61-
join_descriptors = get_join_descriptors_from_edge_descriptors(direct_edges)
61+
join_descriptors = get_join_descriptors_from_edge_descriptors(edges)
6262

6363
return SQLAlchemySchemaInfo(
6464
graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors

graphql_compiler/schema_generation/sqlalchemy/edge_descriptors.py

Lines changed: 81 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,120 @@
11
# Copyright 2019-present Kensho Technologies, LLC.
2-
from typing import Dict, NamedTuple, Set
2+
from dataclasses import dataclass
3+
from typing import AbstractSet, Dict, Set, Tuple, Union
34
import warnings
45

56
import six
67
from sqlalchemy import Table
78

89
from ...schema import INBOUND_EDGE_FIELD_PREFIX, OUTBOUND_EDGE_FIELD_PREFIX
9-
from ...schema.schema_info import DirectJoinDescriptor
10+
from ...schema.schema_info import CompositeJoinDescriptor, DirectJoinDescriptor, JoinDescriptor
1011
from ..exceptions import InvalidSQLEdgeError
1112

1213

13-
class DirectEdgeDescriptor(NamedTuple):
14+
@dataclass(frozen=True)
15+
class DirectEdgeDescriptor:
16+
"""Represents a bidirectional edge between two vertices backed by DirectJoinDescriptors."""
17+
1418
from_vertex: str # Name of the source vertex.
1519
from_column: str # Name of the column of the underlying source table to use for SQL join.
1620
to_vertex: str # Name of the destination vertex.
1721
to_column: str # Name of the column of the underlying destination table to use for SQL join.
1822

1923

24+
@dataclass(frozen=True)
25+
class CompositeEdgeDescriptor:
26+
"""Represents a bidirectional edge between two vertices backed by CompositeJoinDescriptors."""
27+
28+
from_vertex: str # Name of the source vertex
29+
to_vertex: str # Name of the destination vertex
30+
31+
# (from_column, to_column) pairs, where from_column is on the origin table
32+
# and to_column is on the destination table of the join.
33+
column_pairs: AbstractSet[Tuple[str, str]]
34+
35+
def __post_init__(self) -> None:
36+
"""Validate fields."""
37+
if not self.column_pairs:
38+
raise AssertionError("The column_pairs field is expected to be non-empty.")
39+
40+
41+
EdgeDescriptor = Union[DirectEdgeDescriptor, CompositeEdgeDescriptor]
42+
43+
2044
def get_join_descriptors_from_edge_descriptors(
21-
direct_edges: Dict[str, DirectEdgeDescriptor]
22-
) -> Dict[str, Dict[str, DirectJoinDescriptor]]:
45+
direct_edges: Dict[str, EdgeDescriptor]
46+
) -> Dict[str, Dict[str, JoinDescriptor]]:
2347
"""Return the SQL edges in a format more suited to resolving vertex fields."""
24-
join_descriptors: Dict[str, Dict[str, DirectJoinDescriptor]] = {}
25-
for edge_name, direct_edge_descriptor in direct_edges.items():
26-
from_column = direct_edge_descriptor.from_column
27-
to_column = direct_edge_descriptor.to_column
28-
join_descriptors.setdefault(direct_edge_descriptor.from_vertex, {})
29-
join_descriptors.setdefault(direct_edge_descriptor.to_vertex, {})
48+
join_descriptors: Dict[str, Dict[str, JoinDescriptor]] = {}
49+
for edge_name, edge_descriptor in direct_edges.items():
50+
join_descriptors.setdefault(edge_descriptor.from_vertex, {})
51+
join_descriptors.setdefault(edge_descriptor.to_vertex, {})
3052
out_edge_name = OUTBOUND_EDGE_FIELD_PREFIX + edge_name
3153
in_edge_name = INBOUND_EDGE_FIELD_PREFIX + edge_name
32-
join_descriptors[direct_edge_descriptor.from_vertex][out_edge_name] = DirectJoinDescriptor(
33-
from_column, to_column
34-
)
35-
join_descriptors[direct_edge_descriptor.to_vertex][in_edge_name] = DirectJoinDescriptor(
36-
to_column, from_column
37-
)
54+
if isinstance(edge_descriptor, DirectEdgeDescriptor):
55+
from_column = edge_descriptor.from_column
56+
to_column = edge_descriptor.to_column
57+
join_descriptors[edge_descriptor.from_vertex][out_edge_name] = DirectJoinDescriptor(
58+
from_column, to_column
59+
)
60+
join_descriptors[edge_descriptor.to_vertex][in_edge_name] = DirectJoinDescriptor(
61+
to_column, from_column
62+
)
63+
elif isinstance(edge_descriptor, CompositeEdgeDescriptor):
64+
join_descriptors[edge_descriptor.from_vertex][out_edge_name] = CompositeJoinDescriptor(
65+
edge_descriptor.column_pairs
66+
)
67+
join_descriptors[edge_descriptor.to_vertex][in_edge_name] = CompositeJoinDescriptor(
68+
{
69+
(to_column, from_column)
70+
for from_column, to_column in edge_descriptor.column_pairs
71+
}
72+
)
73+
else:
74+
raise AssertionError(
75+
f"Unknown edge descriptor type {edge_descriptor}: "
76+
f"{type(edge_descriptor)} for edge {edge_name}."
77+
)
3878
return join_descriptors
3979

4080

4181
def validate_edge_descriptors(
42-
vertex_name_to_table: Dict[str, Table], direct_edges: Dict[str, DirectEdgeDescriptor]
82+
vertex_name_to_table: Dict[str, Table], edges: Dict[str, EdgeDescriptor]
4383
) -> None:
4484
"""Validate that the edge descriptors do not reference non-existent vertices or columns."""
4585
# TODO(pmantica1): Validate that columns in a direct SQL edge have comparable types.
4686
# TODO(pmantica1): Validate that columns don't have types that probably shouldn't be used for
4787
# joins, (e.g. array types).
48-
for edge_name, direct_edge_descriptor in six.iteritems(direct_edges):
49-
for vertex_name, column_name in (
50-
(direct_edge_descriptor.from_vertex, direct_edge_descriptor.from_column),
51-
(direct_edge_descriptor.to_vertex, direct_edge_descriptor.to_column),
52-
):
88+
for edge_name, edge_descriptor in six.iteritems(edges):
89+
if isinstance(edge_descriptor, DirectEdgeDescriptor):
90+
vertex_column_pairs = [
91+
(edge_descriptor.from_vertex, edge_descriptor.from_column),
92+
(edge_descriptor.to_vertex, edge_descriptor.to_column),
93+
]
94+
elif isinstance(edge_descriptor, CompositeEdgeDescriptor):
95+
vertex_column_pairs = [
96+
(edge_descriptor.from_vertex, from_column)
97+
for from_column, _ in edge_descriptor.column_pairs
98+
] + [
99+
(edge_descriptor.to_vertex, to_column)
100+
for _, to_column in edge_descriptor.column_pairs
101+
]
102+
else:
103+
raise AssertionError(
104+
f"Unknown edge descriptor type {edge_descriptor}: "
105+
f"{type(edge_descriptor)} for edge {edge_name}."
106+
)
107+
108+
for vertex_name, column_name in vertex_column_pairs:
53109
if vertex_name not in vertex_name_to_table:
54110
raise InvalidSQLEdgeError(
55111
"SQL edge {} with edge descriptor {} references a "
56-
"non-existent vertex {}".format(edge_name, direct_edge_descriptor, vertex_name)
112+
"non-existent vertex {}".format(edge_name, edge_descriptor, vertex_name)
57113
)
58114
if column_name not in vertex_name_to_table[vertex_name].columns:
59115
raise InvalidSQLEdgeError(
60116
"SQL edge {} with edge descriptor {} references a "
61-
"non-existent column {}".format(edge_name, direct_edge_descriptor, column_name)
117+
"non-existent column {}".format(edge_name, edge_descriptor, column_name)
62118
)
63119

64120

graphql_compiler/schema_generation/sqlalchemy/schema_graph_builder.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020

2121

22-
def get_sqlalchemy_schema_graph(vertex_name_to_table, direct_edges):
22+
def get_sqlalchemy_schema_graph(vertex_name_to_table, edges):
2323
"""Return a SchemaGraph from the metadata.
2424
2525
Args:
@@ -31,24 +31,24 @@ def get_sqlalchemy_schema_graph(vertex_name_to_table, direct_edges):
3131
will have the same name as the underlying columns and columns with
3232
unsupported types, (SQL types with no matching GraphQL type), will be
3333
ignored.
34-
direct_edges: dict, str -> DirectEdgeDescriptor. This dictionary will be used to generate
34+
edges: dict, str -> EdgeDescriptor. This dictionary will be used to generate
3535
EdgeType objects. The name of the EdgeType objects will be dictionary keys and
36-
the connections will be deduced from the DirectEdgeDescriptor objects.
36+
the connections will be deduced from the EdgeDescriptor objects.
3737
3838
Returns:
3939
SchemaGraph reflecting the specified metadata.
4040
"""
4141
validate_that_tables_belong_to_the_same_metadata_object(vertex_name_to_table.values())
42-
validate_edge_descriptors(vertex_name_to_table, direct_edges)
42+
validate_edge_descriptors(vertex_name_to_table, edges)
4343
validate_that_tables_have_primary_keys(vertex_name_to_table.values())
4444

4545
vertex_types = {
4646
vertex_name: _get_vertex_type_from_sqlalchemy_table(vertex_name, table)
4747
for vertex_name, table in vertex_name_to_table.items()
4848
}
4949
edge_types = {
50-
edge_name: _get_edge_type_from_direct_edge(edge_name, direct_edge_descriptor)
51-
for edge_name, direct_edge_descriptor in direct_edges.items()
50+
edge_name: _get_edge_type_from_edge(edge_name, edge_descriptor)
51+
for edge_name, edge_descriptor in edges.items()
5252
}
5353
elements = merge_non_overlapping_dicts(vertex_types, edge_types)
5454
elements.update(vertex_types)
@@ -72,15 +72,15 @@ def _get_vertex_type_from_sqlalchemy_table(vertex_name, table):
7272
return VertexType(vertex_name, False, properties, {})
7373

7474

75-
def _get_edge_type_from_direct_edge(edge_name, direct_edge_descriptor):
76-
"""Return the EdgeType corresponding to a direct SQL edge."""
75+
def _get_edge_type_from_edge(edge_name, edge_descriptor):
76+
"""Return the EdgeType corresponding to a SQL edge."""
7777
return EdgeType(
7878
edge_name,
7979
False,
8080
{},
8181
{},
82-
base_in_connection=direct_edge_descriptor.from_vertex,
83-
base_out_connection=direct_edge_descriptor.to_vertex,
82+
base_in_connection=edge_descriptor.from_vertex,
83+
base_out_connection=edge_descriptor.to_vertex,
8484
)
8585

8686

graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
get_join_descriptors_from_edge_descriptors,
2626
)
2727
from ...schema_generation.sqlalchemy.edge_descriptors import (
28+
CompositeEdgeDescriptor,
2829
DirectEdgeDescriptor,
2930
DirectJoinDescriptor,
3031
generate_direct_edge_descriptors_from_foreign_keys,
@@ -210,6 +211,24 @@ def test_index_generation_from_unique_constraint(self) -> None:
210211
indexes,
211212
)
212213

214+
def test_composite_edge(self) -> None:
215+
edges = {
216+
"composite_edge": CompositeEdgeDescriptor(
217+
"Table1",
218+
"TableWithMultiplePrimaryKeyColumns",
219+
{
220+
("source_column", "primary_key_column1"),
221+
("unique_column", "primary_key_column2"),
222+
},
223+
)
224+
}
225+
schema_info = get_sqlalchemy_schema_info(_get_test_vertex_name_to_table(), edges, dialect())
226+
self.assertTrue("out_composite_edge" in schema_info.join_descriptors["Table1"])
227+
self.assertTrue(
228+
"in_composite_edge"
229+
in schema_info.join_descriptors["TableWithMultiplePrimaryKeyColumns"]
230+
)
231+
213232

214233
@pytest.mark.filterwarnings("ignore: Ignored .* edges implied by composite foreign keys.*")
215234
class SQLAlchemyForeignKeyEdgeGenerationTests(unittest.TestCase):

0 commit comments

Comments
 (0)