26
26
27
27
Selection = Tuple [Tuple [scalar_exprs .Expression , str ], ...]
28
28
29
# Node kinds that common_selection_root() is willing to step through (via
# their ``child`` link) while searching for a subtree shared by both sides of
# a join: projections, filters, reversals and reorderings.
REWRITABLE_NODE_TYPES = (
    nodes.ProjectionNode,
    nodes.FilterNode,
    nodes.ReversedNode,
    nodes.OrderByNode,
)
35
+
29
36
30
37
@dataclasses .dataclass (frozen = True )
31
38
class SquashedSelect :
32
- """Squash together as many nodes as possible , separating out the projection, filter and reordering expressions."""
39
+ """Squash nodes together until target node , separating out the projection, filter and reordering expressions."""
33
40
34
41
root : nodes .BigFrameNode
35
42
columns : Tuple [Tuple [scalar_exprs .Expression , str ], ...]
@@ -38,25 +45,25 @@ class SquashedSelect:
38
45
reverse_root : bool = False
39
46
40
47
@classmethod
def from_node_span(
    cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode
) -> SquashedSelect:
    """Squash the chain of nodes from ``node`` down to ``target``.

    Follows ``child`` links starting at ``node`` until a node equal to
    ``target`` is reached, then folds every projection, filter, reversal and
    ordering met on the way into a single selection rooted at that node.

    Args:
        node: Top of the span to squash.
        target: Node at which to stop; it becomes the root of the result.

    Returns:
        A SquashedSelect whose root is the node equal to ``target``.

    Raises:
        ValueError: If a node that is not a projection, filter, reversal or
            ordering is encountered before ``target`` is reached.
    """
    # Walk down iteratively (rather than recursing once per node) so the
    # length of the span is not bounded by the interpreter recursion limit.
    # Each entry is (method name, positional args), outermost node first.
    pending = []
    current = node
    while current != target:
        if isinstance(current, nodes.ProjectionNode):
            pending.append(("project", (current.assignments,)))
        elif isinstance(current, nodes.FilterNode):
            pending.append(("filter", (current.predicate,)))
        elif isinstance(current, nodes.ReversedNode):
            pending.append(("reverse", ()))
        elif isinstance(current, nodes.OrderByNode):
            pending.append(("order_with", (current.by,)))
        else:
            raise ValueError(f"Cannot rewrite node {current}")
        current = current.child

    # Base selection: pass every column of the target through unchanged.
    selection = tuple(
        (scalar_exprs.UnboundVariableExpression(col_id), col_id)
        for col_id in get_node_column_ids(current)
    )
    squashed = cls(current, selection, None, ())

    # Apply the collected operations innermost-first, which is the order the
    # nodes are stacked on top of the target.
    for method_name, args in reversed(pending):
        squashed = getattr(squashed, method_name)(*args)
    return squashed
60
67
61
68
@property
62
69
def column_lookup (self ) -> Mapping [str , scalar_exprs .Expression ]:
@@ -98,9 +105,10 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
98
105
self .root , self .columns , self .predicate , new_ordering , self .reverse_root
99
106
)
100
107
101
- def can_join (
108
+ def can_merge (
102
109
self , right : SquashedSelect , join_def : join_defs .JoinDefinition
103
110
) -> bool :
111
+ """Determines whether the two selections can be merged into a single selection."""
104
112
if join_def .type == "cross" :
105
113
# Cannot convert cross join to projection
106
114
return False
@@ -116,14 +124,14 @@ def can_join(
116
124
return False
117
125
return True
118
126
119
- def maybe_merge (
127
+ def merge (
120
128
self ,
121
129
right : SquashedSelect ,
122
130
join_type : join_defs .JoinType ,
123
131
mappings : Tuple [join_defs .JoinColumnMapping , ...],
124
- ) -> Optional [ SquashedSelect ] :
132
+ ) -> SquashedSelect :
125
133
if self .root != right .root :
126
- return None
134
+ raise ValueError ( "Cannot merge expressions with different roots" )
127
135
# Mask columns and remap names to expected schema
128
136
lselection = self .columns
129
137
rselection = right .columns
@@ -196,28 +204,40 @@ def expand(self) -> nodes.BigFrameNode:
196
204
return nodes .ProjectionNode (child = root , assignments = self .columns )
197
205
198
206
199
- def maybe_squash_projection (node : nodes .BigFrameNode ) -> nodes .BigFrameNode :
200
- if isinstance (node , nodes .ProjectionNode ) and isinstance (
201
- node .child , nodes .ProjectionNode
202
- ):
203
- # Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings
204
- return SquashedSelect .from_node (node , projections_only = True ).expand ()
205
- return node
206
-
207
-
208
207
def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
    """Replace a join of two selections over a shared subtree with one selection.

    Both join inputs are squashed down to the subtree they have in common. If
    such a subtree exists and the two squashed selections can be merged under
    the join definition, the merged selection is expanded and returned.
    Otherwise ``join_node`` is returned unchanged.
    """
    left_tree = join_node.left_child
    right_tree = join_node.right_child

    shared_root = common_selection_root(left_tree, right_tree)
    if shared_root is not None:
        lhs = SquashedSelect.from_node_span(left_tree, shared_root)
        rhs = SquashedSelect.from_node_span(right_tree, shared_root)
        join_def = join_node.join
        if lhs.can_merge(rhs, join_def):
            merged = lhs.merge(rhs, join_def.type, join_def.mappings)
            return merged.expand()

    # No shared subtree, or the selections are not mergeable: keep the join.
    return join_node
222
+
223
+
224
def join_as_projection(
    l_node: nodes.BigFrameNode,
    r_node: nodes.BigFrameNode,
    mappings: Tuple[join_defs.JoinColumnMapping, ...],
    how: join_defs.JoinType,
) -> Optional[nodes.BigFrameNode]:
    """Express a join of two trees over a shared subtree as a single selection.

    Args:
        l_node: Left input of the join.
        r_node: Right input of the join.
        mappings: Column mappings describing the join output.
        how: Join type.

    Returns:
        The expanded merged selection, or ``None`` when the two trees do not
        share a subtree reachable through rewritable nodes.

    Raises:
        ValueError: Propagated from ``SquashedSelect.merge`` if the two
            squashed selections end up with different roots.
    """
    rewrite_common_node = common_selection_root(l_node, r_node)
    if rewrite_common_node is None:
        return None

    left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
    right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
    # merge() returns a SquashedSelect or raises; it never yields None, so no
    # None check is needed before expanding.
    # NOTE(review): unlike maybe_rewrite_join, this path does not consult
    # can_merge() first -- presumably callers only pass mergeable joins; confirm.
    return left_side.merge(right_side, how, mappings).expand()
221
241
222
242
223
243
def remap_names (
@@ -311,3 +331,25 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]:
311
331
import bigframes .core
312
332
313
333
return tuple (bigframes .core .ArrayValue (node ).column_ids )
334
+
335
+
336
def common_selection_root(
    l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode
) -> Optional[nodes.BigFrameNode]:
    """Find the subtree shared by both sides of a join, if there is one.

    Starting from each tree's top node, only nodes whose type is listed in
    REWRITABLE_NODE_TYPES are stepped through (via ``child``). The result is
    the first node on the right-hand chain that also occurs on the left-hand
    chain, or ``None`` when the two chains never meet.
    """
    # Every node on the left chain, including the first non-rewritable node
    # that terminates it.
    left_chain = {l_tree}
    cursor = l_tree
    while isinstance(cursor, REWRITABLE_NODE_TYPES):
        cursor = cursor.child
        left_chain.add(cursor)

    # Descend the right chain until it hits something seen on the left.
    candidate = r_tree
    while candidate not in left_chain:
        if not isinstance(candidate, REWRITABLE_NODE_TYPES):
            # Reached the end of the right chain without a match.
            return None
        candidate = candidate.child
    return candidate
0 commit comments