Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion graphql_compiler/schema/schema_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import six
import sqlalchemy
from sqlalchemy.dialects.mssql import dialect as mssql_dialect
from sqlalchemy.dialects.mssql.base import MSDialect
from sqlalchemy.dialects.mysql import dialect as mysql_dialect
from sqlalchemy.dialects.postgresql import dialect as postgresql_dialect
from sqlalchemy.engine.interfaces import Dialect
Expand Down Expand Up @@ -280,6 +281,10 @@ class SQLAlchemySchemaInfo:
# the schema and the tables dictionary.
join_descriptors: Dict[str, Dict[str, JoinDescriptor]]

# Whether or not fold postprocessing should be applied to the results from queries
# compiled with this schema.
requires_fold_postprocessing: bool


def make_sqlalchemy_schema_info(
schema: GraphQLSchema,
Expand All @@ -288,6 +293,8 @@ def make_sqlalchemy_schema_info(
vertex_name_to_table: Dict[str, sqlalchemy.Table],
join_descriptors: Dict[str, Dict[str, JoinDescriptor]],
validate: bool = True,
*,
requires_fold_postprocessing: Optional[bool] = None
) -> SQLAlchemySchemaInfo:
"""Make a SQLAlchemySchemaInfo if the input provided is valid.

Expand Down Expand Up @@ -322,6 +329,9 @@ def make_sqlalchemy_schema_info(
validate: whether to validate that the given inputs are valid for creation of
a SQLAlchemySchemaInfo object. Disabling validation may improve performance for
particularly large schemas, at the risk of constructing an invalid schema info.
requires_fold_postprocessing: whether or not queries compiled against this schema require
fold post-processing. If None, this will be inferred from the
dialect.

Returns:
SQLAlchemySchemaInfo containing the input arguments provided
Expand Down Expand Up @@ -363,8 +373,20 @@ def make_sqlalchemy_schema_info(
"for property field {}".format(type_name, field_name)
)

# Infer whether fold post-processing is required if not explicitly given.
if requires_fold_postprocessing is None:
if isinstance(dialect, MSDialect):
requires_fold_postprocessing = True
else:
requires_fold_postprocessing = False

return SQLAlchemySchemaInfo(
schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors
schema,
type_equivalence_hints,
dialect,
vertex_name_to_table,
join_descriptors,
requires_fold_postprocessing,
)


Expand Down
77 changes: 50 additions & 27 deletions graphql_compiler/schema_generation/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
# Copyright 2019-present Kensho Technologies, LLC.
from typing import Dict, Optional

from graphql.type.definition import GraphQLType
from sqlalchemy import Table
from sqlalchemy.dialects.mssql.base import MSDialect
from sqlalchemy.engine.interfaces import Dialect

from ...schema.schema_info import SQLAlchemySchemaInfo
from ..graphql_schema import get_graphql_schema_from_schema_graph
from .edge_descriptors import get_join_descriptors_from_edge_descriptors
from .edge_descriptors import EdgeDescriptor, get_join_descriptors_from_edge_descriptors
from .schema_graph_builder import get_sqlalchemy_schema_graph


def get_sqlalchemy_schema_info(
vertex_name_to_table, edges, dialect, class_to_field_type_overrides=None
):
vertex_name_to_table: Dict[str, Table],
edges: Dict[str, EdgeDescriptor],
dialect: Dialect,
class_to_field_type_overrides: Optional[Dict[str, Dict[str, GraphQLType]]] = None,
*,
requires_fold_postprocessing: Optional[bool] = None
) -> SQLAlchemySchemaInfo:
"""Return a SQLAlchemySchemaInfo from the metadata.

Relational databases are supported by compiling to SQLAlchemy core as an intermediate
Expand All @@ -23,29 +35,28 @@ def get_sqlalchemy_schema_info(
the compiler relies on these to compile GraphQL to SQL.

Args:
vertex_name_to_table: dict, str -> SQLAlchemy Table. This dictionary is used to generate the
GraphQL objects in the schema in the SQLAlchemySchemaInfo. Each
SQLAlchemyTable will be represented as a GraphQL object. The GraphQL
object names are the dictionary keys. The fields of the GraphQL
objects will be inferred from the columns of the underlying tables.
The fields will have the same name as the underlying columns and
columns with unsupported types, (SQL types with no matching GraphQL
type), will be ignored.
edges: dict, str-> EdgeDescriptor. The traversal of an edge
edge gets compiled to a SQL join in graphql_to_sql(). Therefore, each
EdgeDescriptor not only specifies the source and destination GraphQL
objects, but also which columns to use to use when generating a SQL join
between the underlying source and destination tables. The names of the edges
are the keys in the dictionary and the edges will be rendered as vertex fields
named out_<edgeName> and in_<edgeName> in the source and destination GraphQL
objects respectively. The edge names must not conflict with the GraphQL
object names.
dialect: sqlalchemy.engine.interfaces.Dialect, specifying the dialect we are compiling to
(e.g. sqlalchemy.dialects.mssql.dialect()).
class_to_field_type_overrides: optional dict, class name -> {field name -> field type},
(string -> {string -> GraphQLType}). Used to override the
type of a field in the class where it's first defined and all
the class's subclasses.
vertex_name_to_table: dictionary used to generate the GraphQL objects in the schema
in the SQLAlchemySchemaInfo. Each SQLAlchemyTable will be represented
as a GraphQL object. The GraphQL object names are the dictionary keys.
The fields of the GraphQL objects will be inferred from the columns
of the underlying tables. The fields will have the same name as the
underlying columns and columns with unsupported types (SQL types
with no matching GraphQL type) will be ignored.
edges: dictionary mapping edge name to edge descriptor. The traversal of an edge
gets compiled to a SQL join in graphql_to_sql(). Therefore, each EdgeDescriptor not
only specifies the source and destination GraphQL objects, but also which columns to
use to use when generating a SQL join between the underlying source and destination
tables. The names of the edges are the keys in the dictionary and the edges will be
rendered as vertex fields named out_<edgeName> and in_<edgeName> in the source and
destination GraphQL objects respectively. The edge names must not conflict with the
GraphQL object names.
dialect: dialect we are compiling to (e.g. sqlalchemy.dialects.mssql.dialect()).
class_to_field_type_overrides: mapping class name to a dictionary of field name to field
type. Used to override the type of a field in the class where
it's first defined and all the class's subclasses.
requires_fold_postprocessing: whether or not queries compiled against this schema require
fold post-processing. If None, this will be inferred from the
dialect.

Returns:
SQLAlchemySchemaInfo containing the full information needed to compile SQL queries.
Expand All @@ -60,6 +71,18 @@ def get_sqlalchemy_schema_info(

join_descriptors = get_join_descriptors_from_edge_descriptors(edges)

# Infer whether fold post-processing is required if not explicitly given.
if requires_fold_postprocessing is None:
if isinstance(dialect, MSDialect):
requires_fold_postprocessing = True
else:
requires_fold_postprocessing = False

return SQLAlchemySchemaInfo(
graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors
graphql_schema,
type_equivalence_hints,
dialect,
vertex_name_to_table,
join_descriptors,
requires_fold_postprocessing,
)
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
CompositeEdgeDescriptor,
DirectEdgeDescriptor,
DirectJoinDescriptor,
EdgeDescriptor,
generate_direct_edge_descriptors_from_foreign_keys,
)
from ...schema_generation.sqlalchemy.scalar_type_mapper import try_get_graphql_scalar_type
Expand Down Expand Up @@ -99,7 +100,12 @@ def setUp(self):
join_descriptors = get_join_descriptors_from_edge_descriptors(direct_edges)

self.schema_info = SQLAlchemySchemaInfo(
graphql_schema, type_equivalence_hints, dialect, vertex_name_to_table, join_descriptors
graphql_schema,
type_equivalence_hints,
dialect,
vertex_name_to_table,
join_descriptors,
True,
)

def test_table_vertex_representation(self) -> None:
Expand Down Expand Up @@ -245,7 +251,7 @@ def test_index_generation_from_unique_constraint(self) -> None:
)

def test_composite_edge(self) -> None:
edges = {
edges: Dict[str, EdgeDescriptor] = {
"composite_edge": CompositeEdgeDescriptor(
"Table1",
"TableWithMultiplePrimaryKeyColumns",
Expand Down Expand Up @@ -338,7 +344,7 @@ def setUp(self):
self.vertex_name_to_table = _get_test_vertex_name_to_table()

def test_reference_to_non_existent_source_vertex(self) -> None:
direct_edges = {
direct_edges: Dict[str, EdgeDescriptor] = {
"invalid_source_vertex": DirectEdgeDescriptor(
"InvalidVertexName", "source_column", "ArbitraryObjectName", "destination_column"
)
Expand All @@ -347,7 +353,7 @@ def test_reference_to_non_existent_source_vertex(self) -> None:
get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect())

def test_reference_to_non_existent_destination_vertex(self) -> None:
direct_edges = {
direct_edges: Dict[str, EdgeDescriptor] = {
"invalid_source_vertex": DirectEdgeDescriptor(
"Table1", "source_column", "InvalidVertexName", "destination_column"
)
Expand All @@ -356,7 +362,7 @@ def test_reference_to_non_existent_destination_vertex(self) -> None:
get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect())

def test_reference_to_non_existent_source_column(self) -> None:
direct_edges = {
direct_edges: Dict[str, EdgeDescriptor] = {
"invalid_source_vertex": DirectEdgeDescriptor(
"Table1", "invalid_column_name", "ArbitraryObjectName", "destination_column"
)
Expand All @@ -365,7 +371,7 @@ def test_reference_to_non_existent_source_column(self) -> None:
get_sqlalchemy_schema_info(self.vertex_name_to_table, direct_edges, dialect())

def test_reference_to_non_existent_destination_column(self) -> None:
direct_edges = {
direct_edges: Dict[str, EdgeDescriptor] = {
"invalid_destination_column": DirectEdgeDescriptor(
"Table1", "source_column", "ArbitraryObjectName", "invalid_column_name"
)
Expand Down
6 changes: 5 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -182,13 +182,17 @@ disallow_untyped_defs = False

[mypy-graphql_compiler.schema_generation.sqlalchemy.*]
disallow_untyped_calls = False
disallow_untyped_defs = False

[mypy-graphql_compiler.schema_generation.sqlalchemy.schema_graph_builder.*]
check_untyped_defs = False
disallow_untyped_defs = False

[mypy-graphql_compiler.schema_generation.sqlalchemy.sqlalchemy_reflector.*]
check_untyped_defs = False
disallow_untyped_defs = False

[mypy-graphql_compiler.schema_generation.sqlalchemy.utils.*]
disallow_untyped_defs = False

[mypy-graphql_compiler.schema_transformation.*]
disallow_untyped_calls = False
Expand Down