Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Commit 1065f2c

Browse files
chewselene (Selene Chew) and bojanserafimov
authored
Convert SQLAlchemySchemaInfo to a dataclass (#1010)
* Convert SQLAlchemySchemaInfo to a dataclass

* fix docstring formatting

Co-authored-by: bojanserafimov <[email protected]>

* raise not implemented error for composite join inside SQL folds

* clarify error messages

Co-authored-by: Selene Chew <[email protected]>
Co-authored-by: bojanserafimov <[email protected]>
1 parent 44c63aa commit 1065f2c

File tree

3 files changed

+121
-59
lines changed

3 files changed

+121
-59
lines changed

graphql_compiler/compiler/emit_sql.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -632,7 +632,7 @@ def _get_fold_outputs(self) -> List[Label]:
632632

633633
def add_traversal(
634634
self,
635-
join_descriptor: DirectJoinDescriptor,
635+
join_descriptor: JoinDescriptor,
636636
from_table: Alias,
637637
to_table: Alias,
638638
) -> None:
@@ -643,18 +643,32 @@ def add_traversal(
643643
f"Invalid state encountered during fold {self}."
644644
)
645645

646-
# Ensure that the previous traversal's to_table matches the next traversal's from_table.
647-
if len(self._traversal_descriptors) > 0:
648-
if self._traversal_descriptors[-1].to_table != from_table:
649-
raise AssertionError(
650-
"Received invalid traversal. The previous traversal's to_table "
651-
"should match the next traversal's from_table. Previous to_table was "
652-
f"{self._traversal_descriptors[-1].to_table.description} while the current "
653-
f"from_table was {from_table.description}."
654-
)
655-
self._traversal_descriptors.append(
656-
SQLFoldTraversalDescriptor(join_descriptor, from_table, to_table)
657-
)
646+
# The edge name is not available within the function making the error message
647+
# user-unfriendly. A similar check should be performed prior to calling this function to
648+
# ensure that a user-friendly message is presented containing the composite join-backed
649+
# edge name, but this check is also necessary in order to appease mypy.
650+
if isinstance(join_descriptor, CompositeJoinDescriptor):
651+
raise NotImplementedError(
652+
"Composite joins are not implemented inside of folds for SQL."
653+
)
654+
elif isinstance(join_descriptor, DirectJoinDescriptor):
655+
# Ensure that the previous traversal's to_table matches the next traversal's
656+
# from_table.
657+
if len(self._traversal_descriptors) > 0:
658+
if self._traversal_descriptors[-1].to_table != from_table:
659+
raise AssertionError(
660+
"Received invalid traversal. The previous traversal's to_table "
661+
"should match the next traversal's from_table. Previous to_table was "
662+
f"{self._traversal_descriptors[-1].to_table.description} while the current "
663+
f"from_table was {from_table.description}."
664+
)
665+
self._traversal_descriptors.append(
666+
SQLFoldTraversalDescriptor(join_descriptor, from_table, to_table)
667+
)
668+
else:
669+
raise AssertionError(
670+
f"Unreachable code reached! Unknown JoinDescriptor {type(join_descriptor)}."
671+
)
658672

659673
def mark_output_location_and_fields(
660674
self, output_table: Alias, output_table_location: FoldScopeLocation, output_fields: Set[str]
@@ -972,6 +986,10 @@ def traverse(self, vertex_field: str, optional: bool) -> None:
972986
"Attempting to traverse inside a fold while the _current_location was not a "
973987
f"FoldScopeLocation. _current_location was set to {self._current_location}."
974988
)
989+
# add_traversal performs the same check internally, but checking here is necessary in
990+
# order to give the user a better error message - the vertex_field is not available from
991+
# within add_traversal and therefore cannot be put in a user-friendly error message
992+
# within the function.
975993
if not isinstance(edge, DirectJoinDescriptor):
976994
raise NotImplementedError(
977995
f"Edge {vertex_field} is backed by a CompositeJoinDescriptor, "
@@ -1179,6 +1197,10 @@ def fold(self, fold_scope_location: FoldScopeLocation) -> None:
11791197
join_descriptor = self._sql_schema_info.join_descriptors[self._current_classname][
11801198
full_edge_name
11811199
]
1200+
# add_traversal performs the same check internally, but checking here is necessary in
1201+
# order to give the user a better error message - the full_edge_name is not available from
1202+
# within add_traversal and therefore cannot be put in a user-friendly error message
1203+
# within the function.
11821204
if not isinstance(join_descriptor, DirectJoinDescriptor):
11831205
raise NotImplementedError(
11841206
f"Edge {full_edge_name} requires a JOIN across a composite key, which is currently "

graphql_compiler/schema/schema_info.py

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -226,56 +226,63 @@ def _create_sql_schema_info(
226226
)
227227

228228

229-
# Complete schema information sufficient to compile GraphQL queries to SQLAlchemy
230-
#
231-
# It describes the tables that correspond to each type (object type or interface type),
232-
# and gives instructions on how to perform joins for each vertex field. The property fields on each
233-
# type are implicitly mapped to columns with the same name on the corresponding table.
234-
#
235-
# NOTES:
236-
# - RootSchemaQuery is a special type that does not need a corresponding table.
237-
# - Builtin types like __Schema, __Type, etc. don't need corresponding tables.
238-
# - Builtin fields like _x_count do not need corresponding columns.
239-
SQLAlchemySchemaInfo = namedtuple(
240-
"SQLAlchemySchemaInfo",
241-
(
242-
# GraphQLSchema
243-
"schema",
244-
# optional dict of GraphQL interface or type -> GraphQL union.
245-
# Used as a workaround for GraphQL's lack of support for
246-
# inheritance across "types" (i.e. non-interfaces), as well as a
247-
# workaround for Gremlin's total lack of inheritance-awareness.
248-
# The key-value pairs in the dict specify that the "key" type
249-
# is equivalent to the "value" type, i.e. that the GraphQL type or
250-
# interface in the key is the most-derived common supertype
251-
# of every GraphQL type in the "value" GraphQL union.
252-
# Recursive expansion of type equivalence hints is not performed,
253-
# and only type-level correctness of this argument is enforced.
254-
# See README.md for more details on everything this parameter does.
255-
# *****
256-
# Be very careful with this option, as bad input here will
257-
# lead to incorrect output queries being generated.
258-
# *****
259-
"type_equivalence_hints",
260-
# sqlalchemy.engine.interfaces.Dialect, specifying the dialect we are compiling for
261-
# (e.g. sqlalchemy.dialects.mssql.dialect()).
262-
"dialect",
263-
# dict mapping every graphql object type or interface type name in the schema to
264-
# a sqlalchemy table. Column types that do not exist for this dialect are not allowed.
265-
# All tables are expected to have primary keys.
266-
"vertex_name_to_table",
267-
# dict mapping every graphql object type or interface type name in the schema to:
268-
# dict mapping every vertex field name at that type to a JoinDescriptor. The
269-
# tables the join is to be performed on are not specified. They are inferred from
270-
# the schema and the tables dictionary.
271-
"join_descriptors",
272-
),
273-
)
229+
@dataclass
230+
class SQLAlchemySchemaInfo:
231+
"""Complete schema information sufficient to compile GraphQL queries to SQLAlchemy.
232+
233+
It describes the tables that correspond to each type (object type or interface type),
234+
and gives instructions on how to perform joins for each vertex field. The property fields on
235+
each type are implicitly mapped to columns with the same name on the corresponding table.
236+
237+
Notes:
238+
- RootSchemaQuery is a special type that does not need a corresponding table.
239+
- Builtin types like __Schema, __Type, etc. don't need corresponding tables.
240+
- Builtin fields like _x_count do not need corresponding columns.
241+
242+
TODO: This class is essentially the same as SQLSchemaInfo. SQLSchemaInfo is part of an
243+
incomplete refactor started in
244+
https://github.com/kensho-technologies/graphql-compiler/pull/714
245+
SQLAlchemySchemaInfo is currently used to compile GraphQL to SQL while CommonSchemaInfo
246+
is currently used to compile GraphQL to match, gremlin, and cypher.
247+
"""
248+
249+
schema: GraphQLSchema
250+
251+
# Optional dict of GraphQL interface or type -> GraphQL union.
252+
# Used as a workaround for GraphQL's lack of support for
253+
# inheritance across "types" (i.e. non-interfaces), as well as a
254+
# workaround for Gremlin's total lack of inheritance-awareness.
255+
# The key-value pairs in the dict specify that the "key" type
256+
# is equivalent to the "value" type, i.e. that the GraphQL type or
257+
# interface in the key is the most-derived common supertype
258+
# of every GraphQL type in the "value" GraphQL union.
259+
# Recursive expansion of type equivalence hints is not performed,
260+
# and only type-level correctness of this argument is enforced.
261+
# See README.md for more details on everything this parameter does.
262+
# *****
263+
# Be very careful with this option, as bad input here will
264+
# lead to incorrect output queries being generated.
265+
# *****
266+
type_equivalence_hints: Optional[TypeEquivalenceHintsType]
267+
268+
# SQLAlchemy dialect we are compiling for (e.g. sqlalchemy.dialects.mssql.dialect()).
269+
dialect: Dialect
270+
271+
# Mapping every GraphQL object or interface type name in the schema to the corresponding
272+
# SQLAlchemy table. Column types that do not exist for this dialect are not allowed.
273+
# All tables are expected to have primary keys.
274+
vertex_name_to_table: Dict[str, sqlalchemy.Table]
275+
276+
# Mapping every GraphQL object or interface type name in the schema to:
277+
# dict mapping every vertex field name at that type to a JoinDescriptor. The
278+
# tables the join is to be performed on are not specified. They are inferred from
279+
# the schema and the tables dictionary.
280+
join_descriptors: Dict[str, Dict[str, JoinDescriptor]]
274281

275282

276283
def make_sqlalchemy_schema_info(
277284
schema: GraphQLSchema,
278-
type_equivalence_hints: TypeEquivalenceHintsType,
285+
type_equivalence_hints: Optional[TypeEquivalenceHintsType],
279286
dialect: Dialect,
280287
vertex_name_to_table: Dict[str, sqlalchemy.Table],
281288
join_descriptors: Dict[str, Dict[str, JoinDescriptor]],

graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,26 @@ def test_table_vertex_representation_with_non_default_name(self) -> None:
112112

113113
def test_represent_supported_fields(self) -> None:
114114
table1_graphql_object = self.schema_info.schema.get_type("Table1")
115+
# mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType)
116+
# so performing a manual check.
117+
if not isinstance(table1_graphql_object, GraphQLObjectType):
118+
raise AssertionError(
119+
f"table1_graphql_object expected to be GraphQLObjectType, but was of type "
120+
f"{type(table1_graphql_object)}"
121+
)
115122
self.assertEqual(
116123
table1_graphql_object.fields["column_with_supported_type"].type, GraphQLString
117124
)
118125

119126
def test_ignored_fields_not_supported(self) -> None:
120127
table1_graphql_object = self.schema_info.schema.get_type("Table1")
128+
# mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType)
129+
# so performing a manual check.
130+
if not isinstance(table1_graphql_object, GraphQLObjectType):
131+
raise AssertionError(
132+
f"table1_graphql_object expected to be GraphQLObjectType, but was of type "
133+
f"{type(table1_graphql_object)}"
134+
)
121135
self.assertTrue("column_with_non_supported_type" not in table1_graphql_object.fields)
122136

123137
def test_warn_when_type_is_not_supported(self) -> None:
@@ -140,11 +154,30 @@ def test_do_not_support_sql_tz_aware_datetime_types(self) -> None:
140154

141155
def test_mssql_scalar_type_representation(self) -> None:
142156
table1_graphql_object = self.schema_info.schema.get_type("Table1")
157+
# mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType)
158+
# so performing a manual check.
159+
if not isinstance(table1_graphql_object, GraphQLObjectType):
160+
raise AssertionError(
161+
f"table1_graphql_object expected to be GraphQLObjectType, but was of type "
162+
f"{type(table1_graphql_object)}"
163+
)
143164
self.assertEqual(table1_graphql_object.fields["column_with_mssql_type"].type, GraphQLInt)
144165

145166
def test_direct_sql_edge_representation(self) -> None:
146167
table1_graphql_object = self.schema_info.schema.get_type("Table1")
147168
arbitrarily_named_graphql_object = self.schema_info.schema.get_type("ArbitraryObjectName")
169+
# mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType)
170+
# so performing a manual check.
171+
if not isinstance(table1_graphql_object, GraphQLObjectType):
172+
raise AssertionError(
173+
f"table1_graphql_object expected to be GraphQLObjectType, but was of type "
174+
f"{type(table1_graphql_object)}"
175+
)
176+
if not isinstance(arbitrarily_named_graphql_object, GraphQLObjectType):
177+
raise AssertionError(
178+
f"arbitrarily_named_graphql_object expected to be GraphQLObjectType, but was of "
179+
f"type {type(arbitrarily_named_graphql_object)}"
180+
)
148181
self.assertEqual(
149182
table1_graphql_object.fields["out_test_edge"].type.of_type.name, "ArbitraryObjectName"
150183
)

0 commit comments

Comments
 (0)