diff --git a/graphql_compiler/schema/schema_info.py b/graphql_compiler/schema/schema_info.py index 1a2d8e312..df9cc2841 100644 --- a/graphql_compiler/schema/schema_info.py +++ b/graphql_compiler/schema/schema_info.py @@ -11,6 +11,7 @@ import six import sqlalchemy from sqlalchemy.dialects.mssql import dialect as mssql_dialect +from sqlalchemy.dialects.mssql.base import MSDialect from sqlalchemy.dialects.mysql import dialect as mysql_dialect from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect from sqlalchemy.engine.interfaces import Dialect @@ -280,6 +281,10 @@ class SQLAlchemySchemaInfo: # the schema and the tables dictionary. join_descriptors: Dict[str, Dict[str, JoinDescriptor]] + # Whether or not fold postprocessing should be applied to the results from queries + # compiled with this schema. + requires_fold_postprocessing: bool + def make_sqlalchemy_schema_info( schema: GraphQLSchema, @@ -288,6 +293,8 @@ def make_sqlalchemy_schema_info( vertex_name_to_table: Dict[str, sqlalchemy.Table], join_descriptors: Dict[str, Dict[str, JoinDescriptor]], validate: bool = True, + *, + requires_fold_postprocessing: Optional[bool] = None ) -> SQLAlchemySchemaInfo: """Make a SQLAlchemySchemaInfo if the input provided is valid. @@ -322,6 +329,9 @@ def make_sqlalchemy_schema_info( validate: whether to validate that the given inputs are valid for creation of a SQLAlchemySchemaInfo object. Disabling validation may improve performance for particularly large schemas, at the risk of constructing an invalid schema info. + requires_fold_postprocessing: whether or not queries compiled against this schema require + fold post-processing. If None, this will be inferred from the + dialect. Returns: SQLAlchemySchemaInfo containing the input arguments provided @@ -363,8 +373,20 @@ def make_sqlalchemy_schema_info( "for property field {}".format(type_name, field_name) ) + # Infer whether fold post-processing is required if not explicitly given. + if requires_fold_postprocessing is None: + if isinstance(dialect, MSDialect): + requires_fold_postprocessing = True + else: + requires_fold_postprocessing = False + return SQLAlchemySchemaInfo( - schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors + schema, + type_equivalence_hints, + dialect, + vertex_name_to_table, + join_descriptors, + requires_fold_postprocessing, ) diff --git a/graphql_compiler/schema_generation/sqlalchemy/__init__.py b/graphql_compiler/schema_generation/sqlalchemy/__init__.py index b7a518818..16999ed7b 100644 --- a/graphql_compiler/schema_generation/sqlalchemy/__init__.py +++ b/graphql_compiler/schema_generation/sqlalchemy/__init__.py @@ -1,13 +1,25 @@ # Copyright 2019-present Kensho Technologies, LLC. +from typing import Dict, Optional + +from graphql.type.definition import GraphQLType +from sqlalchemy import Table +from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy.engine.interfaces import Dialect + from ...schema.schema_info import SQLAlchemySchemaInfo from ..graphql_schema import get_graphql_schema_from_schema_graph -from .edge_descriptors import get_join_descriptors_from_edge_descriptors +from .edge_descriptors import EdgeDescriptor, get_join_descriptors_from_edge_descriptors from .schema_graph_builder import get_sqlalchemy_schema_graph def get_sqlalchemy_schema_info( - vertex_name_to_table, edges, dialect, class_to_field_type_overrides=None -): + vertex_name_to_table: Dict[str, Table], + edges: Dict[str, EdgeDescriptor], + dialect: Dialect, + class_to_field_type_overrides: Optional[Dict[str, Dict[str, GraphQLType]]] = None, + *, + requires_fold_postprocessing: Optional[bool] = None +) -> SQLAlchemySchemaInfo: """Return a SQLAlchemySchemaInfo from the metadata. Relational databases are supported by compiling to SQLAlchemy core as an intermediate @@ -23,29 +35,28 @@ def get_sqlalchemy_schema_info( the compiler relies on these to compile GraphQL to SQL. Args: - vertex_name_to_table: dict, str -> SQLAlchemy Table. This dictionary is used to generate the - GraphQL objects in the schema in the SQLAlchemySchemaInfo. Each - SQLAlchemyTable will be represented as a GraphQL object. The GraphQL - object names are the dictionary keys. The fields of the GraphQL - objects will be inferred from the columns of the underlying tables. - The fields will have the same name as the underlying columns and - columns with unsupported types, (SQL types with no matching GraphQL - type), will be ignored. - edges: dict, str-> EdgeDescriptor. The traversal of an edge - edge gets compiled to a SQL join in graphql_to_sql(). Therefore, each - EdgeDescriptor not only specifies the source and destination GraphQL - objects, but also which columns to use to use when generating a SQL join - between the underlying source and destination tables. The names of the edges - are the keys in the dictionary and the edges will be rendered as vertex fields - named out_ and in_ in the source and destination GraphQL - objects respectively. The edge names must not conflict with the GraphQL - object names. - dialect: sqlalchemy.engine.interfaces.Dialect, specifying the dialect we are compiling to - (e.g. sqlalchemy.dialects.mssql.dialect()). - class_to_field_type_overrides: optional dict, class name -> {field name -> field type}, - (string -> {string -> GraphQLType}). Used to override the - type of a field in the class where it's first defined and all - the class's subclasses. + vertex_name_to_table: dictionary used to generate the GraphQL objects in the schema + in the SQLAlchemySchemaInfo. Each SQLAlchemyTable will be represented + as a GraphQL object. The GraphQL object names are the dictionary keys. + The fields of the GraphQL objects will be inferred from the columns + of the underlying tables. The fields will have the same name as the + underlying columns and columns with unsupported types (SQL types + with no matching GraphQL type) will be ignored. + edges: dictionary mapping edge name to edge descriptor. The traversal of an edge + gets compiled to a SQL join in graphql_to_sql(). Therefore, each EdgeDescriptor not + only specifies the source and destination GraphQL objects, but also which columns to + use to use when generating a SQL join between the underlying source and destination + tables. The names of the edges are the keys in the dictionary and the edges will be + rendered as vertex fields named out_ and in_ in the source and + destination GraphQL objects respectively. The edge names must not conflict with the + GraphQL object names. + dialect: dialect we are compiling to (e.g. sqlalchemy.dialects.mssql.dialect()). + class_to_field_type_overrides: mapping class name to a dictionary of field name to field + type. Used to override the type of a field in the class where + it's first defined and all the class's subclasses. + requires_fold_postprocessing: whether or not queries compiled against this schema require + fold post-processing. If None, this will be inferred from the + dialect. Returns: SQLAlchemySchemaInfo containing the full information needed to compile SQL queries. @@ -60,6 +71,18 @@ def get_sqlalchemy_schema_info( join_descriptors = get_join_descriptors_from_edge_descriptors(edges) + # Infer whether fold post-processing is required if not explicitly given. + if requires_fold_postprocessing is None: + if isinstance(dialect, MSDialect): + requires_fold_postprocessing = True + else: + requires_fold_postprocessing = False + return SQLAlchemySchemaInfo( - graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors + graphql_schema, + type_equivalence_hints, + dialect, + vertex_name_to_table, + join_descriptors, + requires_fold_postprocessing, ) diff --git a/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py b/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py index 8883fc4d6..93d978f50 100644 --- a/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py +++ b/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py @@ -28,6 +28,7 @@ CompositeEdgeDescriptor, DirectEdgeDescriptor, DirectJoinDescriptor, + EdgeDescriptor, generate_direct_edge_descriptors_from_foreign_keys, ) from ...schema_generation.sqlalchemy.scalar_type_mapper import try_get_graphql_scalar_type @@ -99,7 +100,12 @@ def setUp(self): join_descriptors = get_join_descriptors_from_edge_descriptors(direct_edges) self.schema_info = SQLAlchemySchemaInfo( - graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors + graphql_schema, + type_equivalence_hints, + dialect, + vertex_name_to_table, + join_descriptors, + True, ) def test_table_vertex_representation(self) -> None: @@ -245,7 +251,7 @@ def test_index_generation_from_unique_constraint(self) -> None: ) def test_composite_edge(self) -> None: - edges = { + edges: Dict[str, EdgeDescriptor] = { "composite_edge": CompositeEdgeDescriptor( "Table1", "TableWithMultiplePrimaryKeyColumns", @@ -338,7 +344,7 @@ def setUp(self): self.vertex_name_to_table = _get_test_vertex_name_to_table() def test_reference_to_non_existent_source_vertex(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_source_vertex": DirectEdgeDescriptor( "InvalidVertexName", "source_column", "ArbitraryObjectName", "destination_column" ) @@ -347,7 +353,7 @@ def test_reference_to_non_existent_source_vertex(self) -> None: get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect()) def test_reference_to_non_existent_destination_vertex(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_source_vertex": DirectEdgeDescriptor( "Table1", "source_column", "InvalidVertexName", "destination_column" ) @@ -356,7 +362,7 @@ def test_reference_to_non_existent_destination_vertex(self) -> None: get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect()) def test_reference_to_non_existent_source_column(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_source_vertex": DirectEdgeDescriptor( "Table1", "invalid_column_name", "ArbitraryObjectName", "destination_column" ) @@ -365,7 +371,7 @@ def test_reference_to_non_existent_source_column(self) -> None: get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect()) def test_reference_to_non_existent_destination_column(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_destination_column": DirectEdgeDescriptor( "Table1", "source_column", "ArbitraryObjectName", "invalid_column_name" ) diff --git a/mypy.ini b/mypy.ini index 21ab52826..479e6e2e9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -182,13 +182,17 @@ disallow_untyped_defs = False [mypy-graphql_compiler.schema_generation.sqlalchemy.*] disallow_untyped_calls = False -disallow_untyped_defs = False [mypy-graphql_compiler.schema_generation.sqlalchemy.schema_graph_builder.*] check_untyped_defs = False +disallow_untyped_defs = False [mypy-graphql_compiler.schema_generation.sqlalchemy.sqlalchemy_reflector.*] check_untyped_defs = False +disallow_untyped_defs = False + +[mypy-graphql_compiler.schema_generation.sqlalchemy.utils.*] +disallow_untyped_defs = False [mypy-graphql_compiler.schema_transformation.*] disallow_untyped_calls = False