1
1
# Copyright 2018-present Kensho Technologies, LLC.
2
2
"""Transform a SqlNode tree into an executable SQLAlchemy query."""
3
3
from dataclasses import dataclass
4
- from typing import Dict , Iterator , List , NamedTuple , Optional , Set , Tuple , Union
4
+ from typing import AbstractSet , Dict , Iterator , List , NamedTuple , Optional , Set , Tuple , Union
5
5
6
6
import six
7
7
import sqlalchemy
21
21
from . import blocks
22
22
from ..global_utils import VertexPath
23
23
from ..schema import COUNT_META_FIELD_NAME
24
- from ..schema .schema_info import DirectJoinDescriptor , SQLAlchemySchemaInfo
24
+ from ..schema .schema_info import (
25
+ CompositeJoinDescriptor ,
26
+ DirectJoinDescriptor ,
27
+ JoinDescriptor ,
28
+ SQLAlchemySchemaInfo ,
29
+ )
25
30
from .compiler_entities import BasicBlock
26
31
from .compiler_frontend import IrAndMetadata
27
32
from .expressions import ContextField , Expression
@@ -143,21 +148,30 @@ def _find_used_columns(
143
148
)
144
149
vertex_field_name = f"{ edge_direction } _{ edge_name } "
145
150
edge = sql_schema_info .join_descriptors [location_info .type .name ][vertex_field_name ]
146
- used_columns .setdefault (get_vertex_path (location ), set ()).add (edge .from_column )
147
- used_columns .setdefault (get_vertex_path (child_location ), set ()).add (edge .to_column )
151
+ if isinstance (edge , DirectJoinDescriptor ):
152
+ columns_at_location = {edge .from_column }
153
+ columns_at_child = {edge .to_column }
154
+ elif isinstance (edge , CompositeJoinDescriptor ):
155
+ columns_at_location = {column_pair [0 ] for column_pair in edge .column_pairs }
156
+ columns_at_child = {column_pair [1 ] for column_pair in edge .column_pairs }
157
+ else :
158
+ raise AssertionError (f"Unknown join descriptor type { edge } : { type (edge )} " )
159
+
160
+ used_columns .setdefault (get_vertex_path (location ), set ()).update (columns_at_location )
161
+ used_columns .setdefault (get_vertex_path (child_location ), set ()).update (columns_at_child )
148
162
149
163
# Check if the edge is recursive
150
164
child_location_info = ir .query_metadata_table .get_location_info (child_location )
151
165
if child_location_info .recursive_scopes_depth > location_info .recursive_scopes_depth :
152
166
# The primary key may be used if the recursive cte base semijoins to
153
167
# the pre-recurse cte by primary key.
154
168
alias = sql_schema_info .vertex_name_to_table [location_info .type .name ].alias ()
155
- primary_key_name = _get_primary_key_name ( alias , location_info . type . name , "@recurse" )
156
- used_columns .setdefault (get_vertex_path (location ), set ()).add ( primary_key_name )
169
+ primary_keys = { column . name for column in alias . primary_key }
170
+ used_columns .setdefault (get_vertex_path (location ), set ()).update ( primary_keys )
157
171
158
172
# The from_column is used at the destination as well, inside the recursive step
159
- used_columns .setdefault (get_vertex_path (child_location ), set ()).add (
160
- edge . from_column
173
+ used_columns .setdefault (get_vertex_path (child_location ), set ()).update (
174
+ columns_at_location
161
175
)
162
176
163
177
# Find outputs used
@@ -780,7 +794,9 @@ def __init__(self, sql_schema_info: SQLAlchemySchemaInfo, ir: IrAndMetadata):
780
794
# Move to the beginning location of the query.
781
795
self ._relocate (ir .query_metadata_table .root_location )
782
796
783
- # Mapping aliases to the column used to join into them.
797
+ # Mapping aliases to one of the columns used to join into them. We use this column
798
+ # to check for LEFT JOIN misses, since it helps us distinguish actual NULL values
799
+ # from values that are NULL because of a LEFT JOIN miss.
784
800
self ._came_from : Dict [Union [Alias , ColumnRouter ], Column ] = {}
785
801
786
802
self ._recurse_needs_cte : bool = False
@@ -840,9 +856,8 @@ def _relocate(self, new_location: BaseLocation):
840
856
self ._current_alias , self ._current_location , output_fields
841
857
)
842
858
843
- # TODO merge from_column and to_column into a joindescriptor
844
859
def _join_to_parent_location (
845
- self , parent_alias : Alias , from_column : str , to_column : str , optional : bool
860
+ self , parent_alias : Alias , join_descriptor : JoinDescriptor , optional : bool
846
861
):
847
862
"""Join the current location to the parent location using the column names specified."""
848
863
if self ._current_alias is None :
@@ -851,7 +866,25 @@ def _join_to_parent_location(
851
866
f"during fold { self } ."
852
867
)
853
868
854
- self ._came_from [self ._current_alias ] = self ._current_alias .c [to_column ]
869
+ # construct on clause for join
870
+ if isinstance (join_descriptor , DirectJoinDescriptor ):
871
+ matching_column_pairs : AbstractSet [Tuple [str , str ]] = {
872
+ (join_descriptor .from_column , join_descriptor .to_column ),
873
+ }
874
+ elif isinstance (join_descriptor , CompositeJoinDescriptor ):
875
+ matching_column_pairs = join_descriptor .column_pairs
876
+ else :
877
+ raise AssertionError (
878
+ f"Unknown join descriptor type { join_descriptor } : { type (join_descriptor )} "
879
+ )
880
+
881
+ if not matching_column_pairs :
882
+ raise AssertionError (
883
+ f"Invalid join descriptor { join_descriptor } , produced no matching column pairs."
884
+ )
885
+
886
+ _ , non_null_column = sorted (matching_column_pairs )[0 ]
887
+ self ._came_from [self ._current_alias ] = self ._current_alias .c [non_null_column ]
855
888
856
889
if self ._is_in_optional_scope () and not optional :
857
890
# For mandatory edges in optional scope, we emit LEFT OUTER JOIN and enforce the
@@ -879,10 +912,17 @@ def _join_to_parent_location(
879
912
)
880
913
)
881
914
915
+ on_clause = sqlalchemy .and_ (
916
+ * (
917
+ parent_alias .c [from_column ] == self ._current_alias .c [to_column ]
918
+ for from_column , to_column in sorted (matching_column_pairs )
919
+ )
920
+ )
921
+
882
922
# Join to where we came from.
883
923
self ._from_clause = self ._from_clause .join (
884
924
self ._current_alias ,
885
- onclause = ( parent_alias . c [ from_column ] == self . _current_alias . c [ to_column ]) ,
925
+ onclause = on_clause ,
886
926
isouter = self ._is_in_optional_scope (),
887
927
)
888
928
@@ -932,11 +972,14 @@ def traverse(self, vertex_field: str, optional: bool) -> None:
932
972
"Attempting to traverse inside a fold while the _current_location was not a "
933
973
f"FoldScopeLocation. _current_location was set to { self ._current_location } ."
934
974
)
975
+ if not isinstance (edge , DirectJoinDescriptor ):
976
+ raise NotImplementedError (
977
+ f"Edge { vertex_field } is backed by a CompositeJoinDescriptor, "
978
+ "so it can't be used inside a @fold scope."
979
+ )
935
980
self ._current_fold .add_traversal (edge , previous_alias , self ._current_alias )
936
981
else :
937
- self ._join_to_parent_location (
938
- previous_alias , edge .from_column , edge .to_column , optional
939
- )
982
+ self ._join_to_parent_location (previous_alias , edge , optional )
940
983
941
984
def _wrap_into_cte (self ) -> None :
942
985
"""Wrap the current query into a cte."""
@@ -1017,6 +1060,11 @@ def recurse(self, vertex_field: str, depth: int) -> None:
1017
1060
)
1018
1061
1019
1062
edge = self ._sql_schema_info .join_descriptors [self ._current_classname ][vertex_field ]
1063
+ if not isinstance (edge , DirectJoinDescriptor ):
1064
+ raise NotImplementedError (
1065
+ f"Edge { vertex_field } requires a JOIN across a composite key, which is currently "
1066
+ f"not supported for use with @recurse."
1067
+ )
1020
1068
primary_key = self ._get_current_primary_key_name ("@recurse" )
1021
1069
1022
1070
# Wrap the query so far into a CTE if it would speed up the recursive query.
@@ -1074,8 +1122,8 @@ def recurse(self, vertex_field: str, depth: int) -> None:
1074
1122
.where (base .c [CTE_DEPTH_NAME ] < literal_depth )
1075
1123
)
1076
1124
1077
- # TODO(bojanserafimov): This creates an unused alias if there's no tags or outputs so far
1078
- self ._join_to_parent_location (previous_alias , primary_key , CTE_KEY_NAME , False )
1125
+ join_descriptor = DirectJoinDescriptor ( primary_key , CTE_KEY_NAME )
1126
+ self ._join_to_parent_location (previous_alias , join_descriptor , False )
1079
1127
1080
1128
def start_global_operations (self ) -> None :
1081
1129
"""Execute a GlobalOperationsStart block."""
@@ -1131,6 +1179,11 @@ def fold(self, fold_scope_location: FoldScopeLocation) -> None:
1131
1179
join_descriptor = self ._sql_schema_info .join_descriptors [self ._current_classname ][
1132
1180
full_edge_name
1133
1181
]
1182
+ if not isinstance (join_descriptor , DirectJoinDescriptor ):
1183
+ raise NotImplementedError (
1184
+ f"Edge { full_edge_name } requires a JOIN across a composite key, which is currently "
1185
+ "not supported for use with @fold."
1186
+ )
1134
1187
1135
1188
# 3. Initialize fold object.
1136
1189
self ._current_fold = FoldSubqueryBuilder (
0 commit comments