From 0156d76fabe90e9938874515f6d61868611438bd Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Tue, 27 Apr 2021 16:55:28 -0400 Subject: [PATCH 1/3] Add fold post-processing bool to SQLAlchemySchemaInfo --- graphql_compiler/schema/schema_info.py | 24 +++++- .../schema_generation/sqlalchemy/__init__.py | 77 ++++++++++++------- .../test_sqlalchemy_schema_generation.py | 18 +++-- 3 files changed, 85 insertions(+), 34 deletions(-) diff --git a/graphql_compiler/schema/schema_info.py b/graphql_compiler/schema/schema_info.py index 1a2d8e312..df9cc2841 100644 --- a/graphql_compiler/schema/schema_info.py +++ b/graphql_compiler/schema/schema_info.py @@ -11,6 +11,7 @@ import six import sqlalchemy from sqlalchemy.dialects.mssql import dialect as mssql_dialect +from sqlalchemy.dialects.mssql.base import MSDialect from sqlalchemy.dialects.mysql import dialect as mysql_dialect from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect from sqlalchemy.engine.interfaces import Dialect @@ -280,6 +281,10 @@ class SQLAlchemySchemaInfo: # the schema and the tables dictionary. join_descriptors: Dict[str, Dict[str, JoinDescriptor]] + # Whether or not fold postprocessing should be applied to the results from queries + # compiled with this schema. + requires_fold_postprocessing: bool + def make_sqlalchemy_schema_info( schema: GraphQLSchema, @@ -288,6 +293,8 @@ def make_sqlalchemy_schema_info( vertex_name_to_table: Dict[str, sqlalchemy.Table], join_descriptors: Dict[str, Dict[str, JoinDescriptor]], validate: bool = True, + *, + requires_fold_postprocessing: Optional[bool] = None ) -> SQLAlchemySchemaInfo: """Make a SQLAlchemySchemaInfo if the input provided is valid. @@ -322,6 +329,9 @@ def make_sqlalchemy_schema_info( validate: whether to validate that the given inputs are valid for creation of a SQLAlchemySchemaInfo object. Disabling validation may improve performance for particularly large schemas, at the risk of constructing an invalid schema info. + requires_fold_postprocessing: whether or not queries compiled against this schema require + fold post-processing. If None, this will be inferred from the + dialect. Returns: SQLAlchemySchemaInfo containing the input arguments provided @@ -363,8 +373,20 @@ def make_sqlalchemy_schema_info( "for property field {}".format(type_name, field_name) ) + # Infer whether fold post-processing is required if not explicitly given. + if requires_fold_postprocessing is None: + if isinstance(dialect, MSDialect): + requires_fold_postprocessing = True + else: + requires_fold_postprocessing = False + return SQLAlchemySchemaInfo( - schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors + schema, + type_equivalence_hints, + dialect, + vertex_name_to_table, + join_descriptors, + requires_fold_postprocessing, ) diff --git a/graphql_compiler/schema_generation/sqlalchemy/__init__.py b/graphql_compiler/schema_generation/sqlalchemy/__init__.py index b7a518818..3406d85d3 100644 --- a/graphql_compiler/schema_generation/sqlalchemy/__init__.py +++ b/graphql_compiler/schema_generation/sqlalchemy/__init__.py @@ -1,13 +1,25 @@ # Copyright 2019-present Kensho Technologies, LLC. +from typing import Dict, Optional + +from graphql.type.definition import GraphQLType +from sqlalchemy import Table +from sqlalchemy.dialects.mssql.base import MSDialect +from sqlalchemy.engine.interfaces import Dialect + from ...schema.schema_info import SQLAlchemySchemaInfo from ..graphql_schema import get_graphql_schema_from_schema_graph -from .edge_descriptors import get_join_descriptors_from_edge_descriptors +from .edge_descriptors import EdgeDescriptor, get_join_descriptors_from_edge_descriptors from .schema_graph_builder import get_sqlalchemy_schema_graph def get_sqlalchemy_schema_info( - vertex_name_to_table, edges, dialect, class_to_field_type_overrides=None -): + vertex_name_to_table: Dict[str, Table], + edges: Dict[str, EdgeDescriptor], + dialect: Dialect, + class_to_field_type_overrides: Optional[Dict[str, Dict[str, GraphQLType]]] = None, + *, + requires_fold_postprocessing: Optional[bool] = None +) -> SQLAlchemySchemaInfo: """Return a SQLAlchemySchemaInfo from the metadata. Relational databases are supported by compiling to SQLAlchemy core as an intermediate @@ -23,29 +35,28 @@ def get_sqlalchemy_schema_info( the compiler relies on these to compile GraphQL to SQL. Args: - vertex_name_to_table: dict, str -> SQLAlchemy Table. This dictionary is used to generate the - GraphQL objects in the schema in the SQLAlchemySchemaInfo. Each - SQLAlchemyTable will be represented as a GraphQL object. The GraphQL - object names are the dictionary keys. The fields of the GraphQL - objects will be inferred from the columns of the underlying tables. - The fields will have the same name as the underlying columns and - columns with unsupported types, (SQL types with no matching GraphQL - type), will be ignored. - edges: dict, str-> EdgeDescriptor. The traversal of an edge - edge gets compiled to a SQL join in graphql_to_sql(). Therefore, each - EdgeDescriptor not only specifies the source and destination GraphQL - objects, but also which columns to use to use when generating a SQL join - between the underlying source and destination tables. The names of the edges - are the keys in the dictionary and the edges will be rendered as vertex fields - named out_ and in_ in the source and destination GraphQL - objects respectively. The edge names must not conflict with the GraphQL - object names. - dialect: sqlalchemy.engine.interfaces.Dialect, specifying the dialect we are compiling to - (e.g. sqlalchemy.dialects.mssql.dialect()). - class_to_field_type_overrides: optional dict, class name -> {field name -> field type}, - (string -> {string -> GraphQLType}). Used to override the - type of a field in the class where it's first defined and all - the class's subclasses. + vertex_name_to_table: Dictionary used to generate the GraphQL objects in the schema + in the SQLAlchemySchemaInfo. Each SQLAlchemyTable will be represented + as a GraphQL object. The GraphQL object names are the dictionary keys. + The fields of the GraphQL objects will be inferred from the columns + of the underlying tables. The fields will have the same name as the + underlying columns and columns with unsupported types (SQL types + with no matching GraphQL type) will be ignored. + edges: Dictionary mapping edge name to edge descriptor. The traversal of an edge + gets compiled to a SQL join in graphql_to_sql(). Therefore, each EdgeDescriptor not + only specifies the source and destination GraphQL objects, but also which columns to + use to use when generating a SQL join between the underlying source and destination + tables. The names of the edges are the keys in the dictionary and the edges will be + rendered as vertex fields named out_ and in_ in the source and + destination GraphQL objects respectively. The edge names must not conflict with the + GraphQL object names. + dialect: dialect we are compiling to (e.g. sqlalchemy.dialects.mssql.dialect()). + class_to_field_type_overrides: Mapping class name to a dictionary of field name to field + type. Used to override the type of a field in the class where + it's first defined and all the class's subclasses. + requires_fold_postprocessing: whether or not queries compiled against this schema require + fold post-processing. If None, this will be inferred from the + dialect. Returns: SQLAlchemySchemaInfo containing the full information needed to compile SQL queries. @@ -60,6 +71,18 @@ def get_sqlalchemy_schema_info( join_descriptors = get_join_descriptors_from_edge_descriptors(edges) + # Infer whether fold post-processing is required if not explicitly given. + if requires_fold_postprocessing is None: + if isinstance(dialect, MSDialect): + requires_fold_postprocessing = True + else: + requires_fold_postprocessing = False + return SQLAlchemySchemaInfo( - graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors + graphql_schema, + type_equivalence_hints, + dialect, + vertex_name_to_table, + join_descriptors, + requires_fold_postprocessing, ) diff --git a/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py b/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py index 8883fc4d6..93d978f50 100644 --- a/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py +++ b/graphql_compiler/tests/schema_generation_tests/test_sqlalchemy_schema_generation.py @@ -28,6 +28,7 @@ CompositeEdgeDescriptor, DirectEdgeDescriptor, DirectJoinDescriptor, + EdgeDescriptor, generate_direct_edge_descriptors_from_foreign_keys, ) from ...schema_generation.sqlalchemy.scalar_type_mapper import try_get_graphql_scalar_type @@ -99,7 +100,12 @@ def setUp(self): join_descriptors = get_join_descriptors_from_edge_descriptors(direct_edges) self.schema_info = SQLAlchemySchemaInfo( - graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors + graphql_schema, + type_equivalence_hints, + dialect, + vertex_name_to_table, + join_descriptors, + True, ) def test_table_vertex_representation(self) -> None: @@ -245,7 +251,7 @@ def test_index_generation_from_unique_constraint(self) -> None: ) def test_composite_edge(self) -> None: - edges = { + edges: Dict[str, EdgeDescriptor] = { "composite_edge": CompositeEdgeDescriptor( "Table1", "TableWithMultiplePrimaryKeyColumns", @@ -338,7 +344,7 @@ def setUp(self): self.vertex_name_to_table = _get_test_vertex_name_to_table() def test_reference_to_non_existent_source_vertex(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_source_vertex": DirectEdgeDescriptor( "InvalidVertexName", "source_column", "ArbitraryObjectName", "destination_column" ) @@ -347,7 +353,7 @@ def test_reference_to_non_existent_source_vertex(self) -> None: get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect()) def test_reference_to_non_existent_destination_vertex(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_source_vertex": DirectEdgeDescriptor( "Table1", "source_column", "InvalidVertexName", "destination_column" ) @@ -356,7 +362,7 @@ def test_reference_to_non_existent_destination_vertex(self) -> None: get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect()) def test_reference_to_non_existent_source_column(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_source_vertex": DirectEdgeDescriptor( "Table1", "invalid_column_name", "ArbitraryObjectName", "destination_column" ) @@ -365,7 +371,7 @@ def test_reference_to_non_existent_source_column(self) -> None: get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect()) def test_reference_to_non_existent_destination_column(self) -> None: - direct_edges = { + direct_edges: Dict[str, EdgeDescriptor] = { "invalid_destination_column": DirectEdgeDescriptor( "Table1", "source_column", "ArbitraryObjectName", "invalid_column_name" ) From 506bb3ededfd809d5d7428cd8419fb3167c4c014 Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Tue, 27 Apr 2021 16:58:54 -0400 Subject: [PATCH 2/3] fix capitalization --- graphql_compiler/schema_generation/sqlalchemy/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/graphql_compiler/schema_generation/sqlalchemy/__init__.py b/graphql_compiler/schema_generation/sqlalchemy/__init__.py index 3406d85d3..16999ed7b 100644 --- a/graphql_compiler/schema_generation/sqlalchemy/__init__.py +++ b/graphql_compiler/schema_generation/sqlalchemy/__init__.py @@ -35,14 +35,14 @@ def get_sqlalchemy_schema_info( the compiler relies on these to compile GraphQL to SQL. Args: - vertex_name_to_table: Dictionary used to generate the GraphQL objects in the schema + vertex_name_to_table: dictionary used to generate the GraphQL objects in the schema in the SQLAlchemySchemaInfo. Each SQLAlchemyTable will be represented as a GraphQL object. The GraphQL object names are the dictionary keys. The fields of the GraphQL objects will be inferred from the columns of the underlying tables. The fields will have the same name as the underlying columns and columns with unsupported types (SQL types with no matching GraphQL type) will be ignored. - edges: Dictionary mapping edge name to edge descriptor. The traversal of an edge + edges: dictionary mapping edge name to edge descriptor. The traversal of an edge gets compiled to a SQL join in graphql_to_sql(). Therefore, each EdgeDescriptor not only specifies the source and destination GraphQL objects, but also which columns to use to use when generating a SQL join between the underlying source and destination @@ -51,7 +51,7 @@ def get_sqlalchemy_schema_info( destination GraphQL objects respectively. The edge names must not conflict with the GraphQL object names. dialect: dialect we are compiling to (e.g. sqlalchemy.dialects.mssql.dialect()). - class_to_field_type_overrides: Mapping class name to a dictionary of field name to field + class_to_field_type_overrides: mapping class name to a dictionary of field name to field type. Used to override the type of a field in the class where it's first defined and all the class's subclasses. requires_fold_postprocessing: whether or not queries compiled against this schema require From c4ebaba525c556d86f3fb645ca2e0b95baf1399e Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Tue, 27 Apr 2021 17:13:48 -0400 Subject: [PATCH 3/3] tighten typing copilot --- mypy.ini | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mypy.ini b/mypy.ini index 21ab52826..479e6e2e9 100644 --- a/mypy.ini +++ b/mypy.ini @@ -182,13 +182,17 @@ disallow_untyped_defs = False [mypy-graphql_compiler.schema_generation.sqlalchemy.*] disallow_untyped_calls = False -disallow_untyped_defs = False [mypy-graphql_compiler.schema_generation.sqlalchemy.schema_graph_builder.*] check_untyped_defs = False +disallow_untyped_defs = False [mypy-graphql_compiler.schema_generation.sqlalchemy.sqlalchemy_reflector.*] check_untyped_defs = False +disallow_untyped_defs = False + +[mypy-graphql_compiler.schema_generation.sqlalchemy.utils.*] +disallow_untyped_defs = False [mypy-graphql_compiler.schema_transformation.*] disallow_untyped_calls = False