1414from __future__ import annotations
1515
1616import dataclasses
17- from typing import Optional , Tuple
17+ from typing import Iterable , Optional , Tuple
1818
1919import bigframes .core .expression
2020import bigframes .core .guid
3030 bigframes .core .nodes .WindowOpNode ,
3131 bigframes .core .nodes .PromoteOffsetsNode ,
3232)
33+ # Combination of selects and additive nodes can be merged as an explicit keyless "row join"
34+ ALIGNABLE_NODES = (
35+ * ADDITIVE_NODES ,
36+ bigframes .core .nodes .SelectionNode ,
37+ )
3338
3439
3540@dataclasses .dataclass (frozen = True )
@@ -70,49 +75,29 @@ def get_expression_spec(
7075 bigframes .core .nodes .PromoteOffsetsNode ,
7176 ),
7277 ):
73- # we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
74- # handles normalizing scalar expressions at the moment.
75- pass
78+ if set (expression .column_references ).isdisjoint (
79+ field .id for field in curr_node .added_fields
80+ ):
81+ # we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
82+ # handles normalizing scalar expressions at the moment.
83+ pass
84+ else :
85+ return ExpressionSpec (expression , curr_node )
7686 else :
7787 return ExpressionSpec (expression , curr_node )
7888 curr_node = curr_node .child
7989
8090
81- def _linearize_trees (
82- base_tree : bigframes .core .nodes .BigFrameNode ,
83- append_tree : bigframes .core .nodes .BigFrameNode ,
84- ) -> bigframes .core .nodes .BigFrameNode :
85- """Linearize two divergent tree who only diverge through different additive nodes."""
86- assert append_tree .projection_base == base_tree .projection_base
87- # base case: append tree does not have any additive nodes to linearize
88- if append_tree == append_tree .projection_base :
89- return base_tree
90- else :
91- assert isinstance (append_tree , ADDITIVE_NODES )
92- return append_tree .replace_child (_linearize_trees (base_tree , append_tree .child ))
93-
94-
95- def combine_nodes (
96- l_node : bigframes .core .nodes .BigFrameNode ,
97- r_node : bigframes .core .nodes .BigFrameNode ,
98- ) -> bigframes .core .nodes .BigFrameNode :
99- assert l_node .projection_base == r_node .projection_base
100- l_node , l_selection = pull_up_selection (l_node )
101- r_node , r_selection = pull_up_selection (
102- r_node , rename_vars = True
103- ) # Rename only right vars to avoid collisions with left vars
104- combined_selection = (* l_selection , * r_selection )
105- merged_node = _linearize_trees (l_node , r_node )
106- return bigframes .core .nodes .SelectionNode (merged_node , combined_selection )
107-
108-
109- def try_join_as_projection (
91+ def try_row_join (
11092 l_node : bigframes .core .nodes .BigFrameNode ,
11193 r_node : bigframes .core .nodes .BigFrameNode ,
11294 join_keys : Tuple [Tuple [str , str ], ...],
11395) -> Optional [bigframes .core .nodes .BigFrameNode ]:
11496 """Joins the two nodes"""
115- if l_node .projection_base != r_node .projection_base :
97+ divergent_node = first_shared_descendent (
98+ l_node , r_node , descendable_types = ALIGNABLE_NODES
99+ )
100+ if divergent_node is None :
116101 return None
117102 # check join keys are equivalent by normalizing the expressions as much as posisble
118103 # instead of just comparing ids
@@ -124,11 +109,35 @@ def try_join_as_projection(
124109 r_node , right_id
125110 ):
126111 return None
127- return combine_nodes (l_node , r_node )
112+
113+ l_node , l_selection = pull_up_selection (l_node , stop = divergent_node )
114+ r_node , r_selection = pull_up_selection (
115+ r_node , stop = divergent_node , rename_vars = True
116+ ) # Rename only right vars to avoid collisions with left vars
117+ combined_selection = (* l_selection , * r_selection )
118+
119+ def _linearize_trees (
120+ base_tree : bigframes .core .nodes .BigFrameNode ,
121+ append_tree : bigframes .core .nodes .BigFrameNode ,
122+ ) -> bigframes .core .nodes .BigFrameNode :
123+ """Linearize two divergent tree who only diverge through different additive nodes."""
124+ # base case: append tree does not have any divergent nodes to linearize
125+ if append_tree == divergent_node :
126+ return base_tree
127+ else :
128+ assert isinstance (append_tree , ADDITIVE_NODES )
129+ return append_tree .replace_child (
130+ _linearize_trees (base_tree , append_tree .child )
131+ )
132+
133+ merged_node = _linearize_trees (l_node , r_node )
134+ return bigframes .core .nodes .SelectionNode (merged_node , combined_selection )
128135
129136
130137def pull_up_selection (
131- node : bigframes .core .nodes .BigFrameNode , rename_vars : bool = False
138+ node : bigframes .core .nodes .BigFrameNode ,
139+ stop : bigframes .core .nodes .BigFrameNode ,
140+ rename_vars : bool = False ,
132141) -> Tuple [
133142 bigframes .core .nodes .BigFrameNode ,
134143 Tuple [
@@ -147,14 +156,14 @@ def pull_up_selection(
147156 Returns:
148157 BigFrameNode, Selections
149158 """
150- if node == node . projection_base : # base case
159+ if node == stop : # base case
151160 return node , tuple (
152161 (bigframes .core .expression .DerefOp (field .id ), field .id )
153162 for field in node .fields
154163 )
155164 assert isinstance (node , (bigframes .core .nodes .SelectionNode , * ADDITIVE_NODES ))
156165 child_node , child_selections = pull_up_selection (
157- node .child , rename_vars = rename_vars
166+ node .child , stop , rename_vars = rename_vars
158167 )
159168 mapping = {out : ref .id for ref , out in child_selections }
160169 if isinstance (node , ADDITIVE_NODES ):
@@ -188,3 +197,30 @@ def pull_up_selection(
188197 )
189198 return child_node , new_selection
190199 raise ValueError (f"Couldn't pull up select from node: { node } " )
200+
201+
202+ ## Traversal helpers
203+ def first_shared_descendent (
204+ left : bigframes .core .nodes .BigFrameNode ,
205+ right : bigframes .core .nodes .BigFrameNode ,
206+ descendable_types : Tuple [type [bigframes .core .nodes .UnaryNode ], ...],
207+ ) -> Optional [bigframes .core .nodes .BigFrameNode ]:
208+ l_path = tuple (descend (left , descendable_types ))
209+ r_path = tuple (descend (right , descendable_types ))
210+ if l_path [- 1 ] != r_path [- 1 ]:
211+ return None
212+
213+ for l_node , r_node in zip (l_path [- len (r_path ) :], r_path [- len (l_path ) :]):
214+ if l_node == r_node :
215+ return l_node
216+ # should be impossible, as l_path[-1] == r_path[-1]
217+ raise ValueError ()
218+
219+
220+ def descend (
221+ root : bigframes .core .nodes .BigFrameNode ,
222+ descendable_types : Tuple [type [bigframes .core .nodes .UnaryNode ], ...],
223+ ) -> Iterable [bigframes .core .nodes .BigFrameNode ]:
224+ yield root
225+ if isinstance (root , descendable_types ):
226+ yield from descend (root .child , descendable_types )
0 commit comments