diff --git a/graphql_compiler/compiler/emit_sql.py b/graphql_compiler/compiler/emit_sql.py index 26f1b41cd..1a706b80b 100644 --- a/graphql_compiler/compiler/emit_sql.py +++ b/graphql_compiler/compiler/emit_sql.py @@ -632,7 +632,7 @@ def _get_fold_outputs(self) -> List[Label]: def add_traversal( self, - join_descriptor: DirectJoinDescriptor, + join_descriptor: JoinDescriptor, from_table: Alias, to_table: Alias, ) -> None: @@ -643,18 +643,32 @@ def add_traversal( f"Invalid state encountered during fold {self}." ) - # Ensure that the previous traversals's from_table matches the next traversal's to_table. - if len(self._traversal_descriptors) > 0: - if self._traversal_descriptors[-1].to_table != from_table: - raise AssertionError( - "Received invalid traversal. The previous traversal's to_table " - "should match the next traversal's from_table. Previous to_table was " - f"{self._traversal_descriptors[-1].to_table.description} while the current " - f"from_table was {from_table.description}." - ) - self._traversal_descriptors.append( - SQLFoldTraversalDescriptor(join_descriptor, from_table, to_table) - ) + # The edge name is not available within this function, making the error message + # user-unfriendly. A similar check should be performed prior to calling this function to + # ensure that a user-friendly message is presented containing the composite join-backed + # edge name, but this check is also necessary in order to appease mypy. + if isinstance(join_descriptor, CompositeJoinDescriptor): + raise NotImplementedError( + "Composite joins are not implemented inside of folds for SQL." + ) + elif isinstance(join_descriptor, DirectJoinDescriptor): + # Ensure that the previous traversal's to_table matches the next traversal's + # from_table. + if len(self._traversal_descriptors) > 0: + if self._traversal_descriptors[-1].to_table != from_table: + raise AssertionError( + "Received invalid traversal. 
The previous traversal's to_table " + "should match the next traversal's from_table. Previous to_table was " + f"{self._traversal_descriptors[-1].to_table.description} while the current " + f"from_table was {from_table.description}." + ) + self._traversal_descriptors.append( + SQLFoldTraversalDescriptor(join_descriptor, from_table, to_table) + ) + else: + raise AssertionError( + f"Unreachable code reached! Unknown JoinDescriptor {type(join_descriptor)}." + ) def mark_output_location_and_fields( self, output_table: Alias, output_table_location: FoldScopeLocation, output_fields: Set[str] @@ -972,6 +986,10 @@ def traverse(self, vertex_field: str, optional: bool) -> None: "Attempting to traverse inside a fold while the _current_location was not a " f"FoldScopeLocation. _current_location was set to {self._current_location}." ) + # add_traversal performs the same check internally, but checking here is necessary in + # order to give the user a better error message - the vertex_field is not available from + # within add_traversal and therefore cannot be put in a user-friendly error message + # within the function. if not isinstance(edge, DirectJoinDescriptor): raise NotImplementedError( f"Edge {vertex_field} is backed by a CompositeJoinDescriptor, " @@ -1179,6 +1197,10 @@ def fold(self, fold_scope_location: FoldScopeLocation) -> None: join_descriptor = self._sql_schema_info.join_descriptors[self._current_classname][ full_edge_name ] + # add_traversal performs the same check internally, but checking here is necessary in + # order to give the user a better error message - the full_edge_name is not available from + # within add_traversal and therefore cannot be put in a user-friendly error message + # within the function. 
if not isinstance(join_descriptor, DirectJoinDescriptor): raise NotImplementedError( f"Edge {full_edge_name} requires a JOIN across a composite key, which is currently " diff --git a/graphql_compiler/schema/schema_info.py b/graphql_compiler/schema/schema_info.py index da44351fa..70a1b6883 100644 --- a/graphql_compiler/schema/schema_info.py +++ b/graphql_compiler/schema/schema_info.py @@ -226,56 +226,63 @@ def _create_sql_schema_info( ) -# Complete schema information sufficient to compile GraphQL queries to SQLAlchemy -# -# It describes the tables that correspond to each type (object type or interface type), -# and gives instructions on how to perform joins for each vertex field. The property fields on each -# type are implicitly mapped to columns with the same name on the corresponding table. -# -# NOTES: -# - RootSchemaQuery is a special type that does not need a corresponding table. -# - Builtin types like __Schema, __Type, etc. don't need corresponding tables. -# - Builtin fields like _x_count do not need corresponding columns. -SQLAlchemySchemaInfo = namedtuple( - "SQLAlchemySchemaInfo", - ( - # GraphQLSchema - "schema", - # optional dict of GraphQL interface or type -> GraphQL union. - # Used as a workaround for GraphQL's lack of support for - # inheritance across "types" (i.e. non-interfaces), as well as a - # workaround for Gremlin's total lack of inheritance-awareness. - # The key-value pairs in the dict specify that the "key" type - # is equivalent to the "value" type, i.e. that the GraphQL type or - # interface in the key is the most-derived common supertype - # of every GraphQL type in the "value" GraphQL union. - # Recursive expansion of type equivalence hints is not performed, - # and only type-level correctness of this argument is enforced. - # See README.md for more details on everything this parameter does. - # ***** - # Be very careful with this option, as bad input here will - # lead to incorrect output queries being generated. 
- # ***** - "type_equivalence_hints", - # sqlalchemy.engine.interfaces.Dialect, specifying the dialect we are compiling for - # (e.g. sqlalchemy.dialects.mssql.dialect()). - "dialect", - # dict mapping every graphql object type or interface type name in the schema to - # a sqlalchemy table. Column types that do not exist for this dialect are not allowed. - # All tables are expected to have primary keys. - "vertex_name_to_table", - # dict mapping every graphql object type or interface type name in the schema to: - # dict mapping every vertex field name at that type to a JoinDescriptor. The - # tables the join is to be performed on are not specified. They are inferred from - # the schema and the tables dictionary. - "join_descriptors", - ), -) +@dataclass +class SQLAlchemySchemaInfo: + """Complete schema information sufficient to compile GraphQL queries to SQLAlchemy. + + It describes the tables that correspond to each type (object type or interface type), + and gives instructions on how to perform joins for each vertex field. The property fields on + each type are implicitly mapped to columns with the same name on the corresponding table. + + Notes: + - RootSchemaQuery is a special type that does not need a corresponding table. + - Builtin types like __Schema, __Type, etc. don't need corresponding tables. + - Builtin fields like _x_count do not need corresponding columns. + + TODO: This class is essentially the same as SQLSchemaInfo. SQLSchemaInfo is part of an + incomplete refactor started in + https://github.com/kensho-technologies/graphql-compiler/pull/714 + SQLAlchemySchemaInfo is currently used to compile GraphQL to SQL while CommonSchemaInfo + is currently used to compile GraphQL to match, gremlin, and cypher. + """ + + schema: GraphQLSchema + + # Optional dict of GraphQL interface or type -> GraphQL union. + # Used as a workaround for GraphQL's lack of support for + # inheritance across "types" (i.e. 
non-interfaces), as well as a + # workaround for Gremlin's total lack of inheritance-awareness. + # The key-value pairs in the dict specify that the "key" type + # is equivalent to the "value" type, i.e. that the GraphQL type or + # interface in the key is the most-derived common supertype + # of every GraphQL type in the "value" GraphQL union. + # Recursive expansion of type equivalence hints is not performed, + # and only type-level correctness of this argument is enforced. + # See README.md for more details on everything this parameter does. + # ***** + # Be very careful with this option, as bad input here will + # lead to incorrect output queries being generated. + # ***** + type_equivalence_hints: Optional[TypeEquivalenceHintsType] + + # The SQLAlchemy dialect we are compiling for (e.g. sqlalchemy.dialects.mssql.dialect()). + dialect: Dialect + + # Mapping every GraphQL object or interface type name in the schema to the corresponding + # SQLAlchemy table. Column types that do not exist for this dialect are not allowed. + # All tables are expected to have primary keys. + vertex_name_to_table: Dict[str, sqlalchemy.Table] + + # Mapping every GraphQL object or interface type name in the schema to: + # dict mapping every vertex field name at that type to a JoinDescriptor. The + # tables the join is to be performed on are not specified. They are inferred from + # the schema and the tables dictionary. 
+ join_descriptors: Dict[str, Dict[str, JoinDescriptor]] def make_sqlalchemy_schema_info( schema: GraphQLSchema, - type_equivalence_hints: TypeEquivalenceHintsType, + type_equivalence_hints: Optional[TypeEquivalenceHintsType], dialect: Dialect, vertex_name_to_table: Dict[str, sqlalchemy.Table], join_descriptors: Dict[str, Dict[str, JoinDescriptor]], diff --git a/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py b/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py index d6af1cc7a..8883fc4d6 100644 --- a/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py +++ b/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py @@ -112,12 +112,26 @@ def test_table_vertex_representation_with_non_default_name(self) -> None: def test_represent_supported_fields(self) -> None: table1_graphql_object = self.schema_info.schema.get_type("Table1") + # mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType) + # so performing a manual check. + if not isinstance(table1_graphql_object, GraphQLObjectType): + raise AssertionError( + f"table1_graphql_object expected to be GraphQLObjectType, but was of type " + f"{type(table1_graphql_object)}" + ) self.assertEqual( table1_graphql_object.fields["column_with_supported_type"].type, GraphQLString ) def test_ignored_fields_not_supported(self) -> None: table1_graphql_object = self.schema_info.schema.get_type("Table1") + # mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType) + # so performing a manual check. 
+ if not isinstance(table1_graphql_object, GraphQLObjectType): + raise AssertionError( + f"table1_graphql_object expected to be GraphQLObjectType, but was of type " + f"{type(table1_graphql_object)}" + ) self.assertTrue("column_with_non_supported_type" not in table1_graphql_object.fields) def test_warn_when_type_is_not_supported(self) -> None: @@ -140,11 +154,30 @@ def test_do_not_support_sql_tz_aware_datetime_types(self) -> None: def test_mssql_scalar_type_representation(self) -> None: table1_graphql_object = self.schema_info.schema.get_type("Table1") + # mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType) + # so performing a manual check. + if not isinstance(table1_graphql_object, GraphQLObjectType): + raise AssertionError( + f"table1_graphql_object expected to be GraphQLObjectType, but was of type " + f"{type(table1_graphql_object)}" + ) self.assertEqual(table1_graphql_object.fields["column_with_mssql_type"].type, GraphQLInt) def test_direct_sql_edge_representation(self) -> None: table1_graphql_object = self.schema_info.schema.get_type("Table1") arbitrarily_named_graphql_object = self.schema_info.schema.get_type("ArbitraryObjectName") + # mypy complained even with self.assertIsInstance(table1_graphql_object, GraphQLObjectType) + # so performing manual checks on both objects. + if not isinstance(table1_graphql_object, GraphQLObjectType): + raise AssertionError( + f"table1_graphql_object expected to be GraphQLObjectType, but was of type " + f"{type(table1_graphql_object)}" + ) + if not isinstance(arbitrarily_named_graphql_object, GraphQLObjectType): + raise AssertionError( + f"arbitrarily_named_graphql_object expected to be GraphQLObjectType, but was of " + f"type {type(arbitrarily_named_graphql_object)}" + ) self.assertEqual( table1_graphql_object.fields["out_test_edge"].type.of_type.name, "ArbitraryObjectName" )