2626
# A selection is an ordered tuple of (expression, output column id) pairs.
Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...]

# Node types that SquashedSelect.from_node_span can fold into a single
# selection. common_selection_root only walks through nodes of these types
# when looking for a node shared by both sides of a join.
REWRITABLE_NODE_TYPES = (
    nodes.ProjectionNode,
    nodes.FilterNode,
    nodes.ReversedNode,
    nodes.OrderByNode,
)
35+
2936
3037@dataclasses .dataclass (frozen = True )
3138class SquashedSelect :
32- """Squash together as many nodes as possible , separating out the projection, filter and reordering expressions."""
39+ """Squash nodes together until target node , separating out the projection, filter and reordering expressions."""
3340
3441 root : nodes .BigFrameNode
3542 columns : Tuple [Tuple [scalar_exprs .Expression , str ], ...]
@@ -38,25 +45,25 @@ class SquashedSelect:
3845 reverse_root : bool = False
3946
@classmethod
def from_node_span(
    cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode
) -> SquashedSelect:
    """Squash the chain of nodes from ``node`` down to ``target``.

    Follows ``child`` links starting at ``node`` until a node equal to
    ``target`` is reached. Every projection, filter, reversal and ordering
    met on the way is folded, innermost first, into a selection rooted at
    that node.

    Args:
        node: Topmost node of the span to squash.
        target: Node at which squashing stops; it becomes the root.

    Returns:
        A SquashedSelect rooted at ``target`` that carries the span's
        projections, predicate and ordering.

    Raises:
        ValueError: If a node that is neither a projection, filter,
            reversal nor ordering is reached before ``target``.
    """
    # Walk downwards first, remembering each node that has to be folded in.
    span = []
    current = node
    while not (current == target):
        if not isinstance(
            current,
            (
                nodes.ProjectionNode,
                nodes.FilterNode,
                nodes.ReversedNode,
                nodes.OrderByNode,
            ),
        ):
            raise ValueError(f"Cannot rewrite node {current}")
        span.append(current)
        current = current.child

    # Start from an identity selection over the root's columns.
    identity_columns = tuple(
        (scalar_exprs.UnboundVariableExpression(column_id), column_id)
        for column_id in get_node_column_ids(current)
    )
    squashed = cls(current, identity_columns, None, ())

    # Fold the remembered nodes back in, closest to the root first.
    for step in reversed(span):
        if isinstance(step, nodes.ProjectionNode):
            squashed = squashed.project(step.assignments)
        elif isinstance(step, nodes.FilterNode):
            squashed = squashed.filter(step.predicate)
        elif isinstance(step, nodes.ReversedNode):
            squashed = squashed.reverse()
        else:
            # Only OrderByNode remains after the type check above.
            squashed = squashed.order_with(step.by)
    return squashed
6067
6168 @property
6269 def column_lookup (self ) -> Mapping [str , scalar_exprs .Expression ]:
@@ -98,9 +105,10 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
98105 self .root , self .columns , self .predicate , new_ordering , self .reverse_root
99106 )
100107
101- def can_join (
108+ def can_merge (
102109 self , right : SquashedSelect , join_def : join_defs .JoinDefinition
103110 ) -> bool :
111+ """Determines whether the two selections can be merged into a single selection."""
104112 if join_def .type == "cross" :
105113 # Cannot convert cross join to projection
106114 return False
@@ -116,14 +124,14 @@ def can_join(
116124 return False
117125 return True
118126
119- def maybe_merge (
127+ def merge (
120128 self ,
121129 right : SquashedSelect ,
122130 join_type : join_defs .JoinType ,
123131 mappings : Tuple [join_defs .JoinColumnMapping , ...],
124- ) -> Optional [ SquashedSelect ] :
132+ ) -> SquashedSelect :
125133 if self .root != right .root :
126- return None
134+ raise ValueError ( "Cannot merge expressions with different roots" )
127135 # Mask columns and remap names to expected schema
128136 lselection = self .columns
129137 rselection = right .columns
@@ -196,28 +204,40 @@ def expand(self) -> nodes.BigFrameNode:
196204 return nodes .ProjectionNode (child = root , assignments = self .columns )
197205
198206
199- def maybe_squash_projection (node : nodes .BigFrameNode ) -> nodes .BigFrameNode :
200- if isinstance (node , nodes .ProjectionNode ) and isinstance (
201- node .child , nodes .ProjectionNode
202- ):
203- # Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings
204- return SquashedSelect .from_node (node , projections_only = True ).expand ()
205- return node
206-
207-
def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
    """Rewrite a join as a single selection when both sides share a node.

    Args:
        join_node: The join to attempt to rewrite.

    Returns:
        The expanded merged selection when ``common_selection_root`` finds
        a shared node and ``can_merge`` accepts the two squashed sides;
        otherwise ``join_node`` itself, unchanged.
    """
    shared_root = common_selection_root(
        join_node.left_child, join_node.right_child
    )
    if shared_root is not None:
        # Squash each side of the join down to the shared node.
        left_select = SquashedSelect.from_node_span(
            join_node.left_child, shared_root
        )
        right_select = SquashedSelect.from_node_span(
            join_node.right_child, shared_root
        )
        if left_select.can_merge(right_select, join_node.join):
            merged = left_select.merge(
                right_select, join_node.join.type, join_node.join.mappings
            )
            return merged.expand()
    # No shared node, or the selections cannot be merged: keep the join.
    return join_node
223+
def join_as_projection(
    l_node: nodes.BigFrameNode,
    r_node: nodes.BigFrameNode,
    mappings: Tuple[join_defs.JoinColumnMapping, ...],
    how: join_defs.JoinType,
) -> Optional[nodes.BigFrameNode]:
    """Express a join of two trees as one selection over a shared node.

    Args:
        l_node: Root of the left tree.
        r_node: Root of the right tree.
        mappings: Join column mappings, passed through to
            ``SquashedSelect.merge``.
        how: Join type, passed through to ``SquashedSelect.merge``.

    Returns:
        The expanded merged selection, or ``None`` when
        ``common_selection_root`` finds no node shared by both trees.
    """
    rewrite_common_node = common_selection_root(l_node, r_node)
    if rewrite_common_node is None:
        return None

    left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
    right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
    # NOTE(review): unlike maybe_rewrite_join, this path does not consult
    # can_merge() before merging -- confirm callers only pass inputs for
    # which the merge is valid.
    # merge() is annotated to return a SquashedSelect and raises ValueError
    # on mismatched roots rather than returning None, so no None check is
    # needed on its result.
    return left_side.merge(right_side, how, mappings).expand()
221241
222242
223243def remap_names (
@@ -311,3 +331,25 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]:
311331 import bigframes .core
312332
313333 return tuple (bigframes .core .ArrayValue (node ).column_ids )
334+
335+
def common_selection_root(
    l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode
) -> Optional[nodes.BigFrameNode]:
    """Find the topmost node of the right chain that also lies on the left chain.

    Each chain starts at the given tree root and follows ``child`` links
    through nodes of the REWRITABLE_NODE_TYPES. The first node that is not
    of those types ends the chain and is still part of it.

    Args:
        l_tree: Root of the left join input.
        r_tree: Root of the right join input.

    Returns:
        The first node found walking down from ``r_tree`` that is also in
        the left chain, or ``None`` if the chains share no node.
    """
    # Gather every node on the left chain, terminal node included.
    left_chain: set[nodes.BigFrameNode] = set()
    cursor = l_tree
    while True:
        left_chain.add(cursor)
        if not isinstance(cursor, REWRITABLE_NODE_TYPES):
            break
        cursor = cursor.child

    # Walk the right chain top-down; the first hit is the shared node.
    cursor = r_tree
    while True:
        if cursor in left_chain:
            return cursor
        if not isinstance(cursor, REWRITABLE_NODE_TYPES):
            return None
        cursor = cursor.child
0 commit comments