14
14
from __future__ import annotations
15
15
16
16
import dataclasses
17
- from typing import Optional , Tuple
17
+ from typing import Iterable , Optional , Tuple
18
18
19
19
import bigframes .core .expression
20
20
import bigframes .core .guid
30
30
bigframes .core .nodes .WindowOpNode ,
31
31
bigframes .core .nodes .PromoteOffsetsNode ,
32
32
)
33
+ # Combination of selects and additive nodes can be merged as an explicit keyless "row join"
34
+ ALIGNABLE_NODES = (
35
+ * ADDITIVE_NODES ,
36
+ bigframes .core .nodes .SelectionNode ,
37
+ )
33
38
34
39
35
40
@dataclasses .dataclass (frozen = True )
@@ -70,49 +75,29 @@ def get_expression_spec(
70
75
bigframes .core .nodes .PromoteOffsetsNode ,
71
76
),
72
77
):
73
- # we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
74
- # handles normalizing scalar expressions at the moment.
75
- pass
78
+ if set (expression .column_references ).isdisjoint (
79
+ field .id for field in curr_node .added_fields
80
+ ):
81
+ # we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
82
+ # handles normalizing scalar expressions at the moment.
83
+ pass
84
+ else :
85
+ return ExpressionSpec (expression , curr_node )
76
86
else :
77
87
return ExpressionSpec (expression , curr_node )
78
88
curr_node = curr_node .child
79
89
80
90
81
- def _linearize_trees (
82
- base_tree : bigframes .core .nodes .BigFrameNode ,
83
- append_tree : bigframes .core .nodes .BigFrameNode ,
84
- ) -> bigframes .core .nodes .BigFrameNode :
85
- """Linearize two divergent tree who only diverge through different additive nodes."""
86
- assert append_tree .projection_base == base_tree .projection_base
87
- # base case: append tree does not have any additive nodes to linearize
88
- if append_tree == append_tree .projection_base :
89
- return base_tree
90
- else :
91
- assert isinstance (append_tree , ADDITIVE_NODES )
92
- return append_tree .replace_child (_linearize_trees (base_tree , append_tree .child ))
93
-
94
-
95
- def combine_nodes (
96
- l_node : bigframes .core .nodes .BigFrameNode ,
97
- r_node : bigframes .core .nodes .BigFrameNode ,
98
- ) -> bigframes .core .nodes .BigFrameNode :
99
- assert l_node .projection_base == r_node .projection_base
100
- l_node , l_selection = pull_up_selection (l_node )
101
- r_node , r_selection = pull_up_selection (
102
- r_node , rename_vars = True
103
- ) # Rename only right vars to avoid collisions with left vars
104
- combined_selection = (* l_selection , * r_selection )
105
- merged_node = _linearize_trees (l_node , r_node )
106
- return bigframes .core .nodes .SelectionNode (merged_node , combined_selection )
107
-
108
-
109
- def try_join_as_projection (
91
+ def try_row_join (
110
92
l_node : bigframes .core .nodes .BigFrameNode ,
111
93
r_node : bigframes .core .nodes .BigFrameNode ,
112
94
join_keys : Tuple [Tuple [str , str ], ...],
113
95
) -> Optional [bigframes .core .nodes .BigFrameNode ]:
114
96
"""Joins the two nodes"""
115
- if l_node .projection_base != r_node .projection_base :
97
+ divergent_node = first_shared_descendent (
98
+ l_node , r_node , descendable_types = ALIGNABLE_NODES
99
+ )
100
+ if divergent_node is None :
116
101
return None
117
102
# check join keys are equivalent by normalizing the expressions as much as posisble
118
103
# instead of just comparing ids
@@ -124,11 +109,35 @@ def try_join_as_projection(
124
109
r_node , right_id
125
110
):
126
111
return None
127
- return combine_nodes (l_node , r_node )
112
+
113
+ l_node , l_selection = pull_up_selection (l_node , stop = divergent_node )
114
+ r_node , r_selection = pull_up_selection (
115
+ r_node , stop = divergent_node , rename_vars = True
116
+ ) # Rename only right vars to avoid collisions with left vars
117
+ combined_selection = (* l_selection , * r_selection )
118
+
119
+ def _linearize_trees (
120
+ base_tree : bigframes .core .nodes .BigFrameNode ,
121
+ append_tree : bigframes .core .nodes .BigFrameNode ,
122
+ ) -> bigframes .core .nodes .BigFrameNode :
123
+ """Linearize two divergent tree who only diverge through different additive nodes."""
124
+ # base case: append tree does not have any divergent nodes to linearize
125
+ if append_tree == divergent_node :
126
+ return base_tree
127
+ else :
128
+ assert isinstance (append_tree , ADDITIVE_NODES )
129
+ return append_tree .replace_child (
130
+ _linearize_trees (base_tree , append_tree .child )
131
+ )
132
+
133
+ merged_node = _linearize_trees (l_node , r_node )
134
+ return bigframes .core .nodes .SelectionNode (merged_node , combined_selection )
128
135
129
136
130
137
def pull_up_selection (
131
- node : bigframes .core .nodes .BigFrameNode , rename_vars : bool = False
138
+ node : bigframes .core .nodes .BigFrameNode ,
139
+ stop : bigframes .core .nodes .BigFrameNode ,
140
+ rename_vars : bool = False ,
132
141
) -> Tuple [
133
142
bigframes .core .nodes .BigFrameNode ,
134
143
Tuple [
@@ -147,14 +156,14 @@ def pull_up_selection(
147
156
Returns:
148
157
BigFrameNode, Selections
149
158
"""
150
- if node == node . projection_base : # base case
159
+ if node == stop : # base case
151
160
return node , tuple (
152
161
(bigframes .core .expression .DerefOp (field .id ), field .id )
153
162
for field in node .fields
154
163
)
155
164
assert isinstance (node , (bigframes .core .nodes .SelectionNode , * ADDITIVE_NODES ))
156
165
child_node , child_selections = pull_up_selection (
157
- node .child , rename_vars = rename_vars
166
+ node .child , stop , rename_vars = rename_vars
158
167
)
159
168
mapping = {out : ref .id for ref , out in child_selections }
160
169
if isinstance (node , ADDITIVE_NODES ):
@@ -188,3 +197,30 @@ def pull_up_selection(
188
197
)
189
198
return child_node , new_selection
190
199
raise ValueError (f"Couldn't pull up select from node: { node } " )
200
+
201
+
202
+ ## Traversal helpers
203
+ def first_shared_descendent (
204
+ left : bigframes .core .nodes .BigFrameNode ,
205
+ right : bigframes .core .nodes .BigFrameNode ,
206
+ descendable_types : Tuple [type [bigframes .core .nodes .UnaryNode ], ...],
207
+ ) -> Optional [bigframes .core .nodes .BigFrameNode ]:
208
+ l_path = tuple (descend (left , descendable_types ))
209
+ r_path = tuple (descend (right , descendable_types ))
210
+ if l_path [- 1 ] != r_path [- 1 ]:
211
+ return None
212
+
213
+ for l_node , r_node in zip (l_path [- len (r_path ) :], r_path [- len (l_path ) :]):
214
+ if l_node == r_node :
215
+ return l_node
216
+ # should be impossible, as l_path[-1] == r_path[-1]
217
+ raise ValueError ()
218
+
219
+
220
+ def descend (
221
+ root : bigframes .core .nodes .BigFrameNode ,
222
+ descendable_types : Tuple [type [bigframes .core .nodes .UnaryNode ], ...],
223
+ ) -> Iterable [bigframes .core .nodes .BigFrameNode ]:
224
+ yield root
225
+ if isinstance (root , descendable_types ):
226
+ yield from descend (root .child , descendable_types )
0 commit comments