Skip to content

Commit f200f68

Browse files
refactor: Share descendent logic between old and new aligner (#1199)
1 parent 1021d57 commit f200f68

File tree

5 files changed

+88
-90
lines changed

5 files changed

+88
-90
lines changed

bigframes/core/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -446,7 +446,7 @@ def try_row_join(
446446
other_node, r_mapping = self.prepare_join_names(other)
447447
import bigframes.core.rewrite
448448

449-
result_node = bigframes.core.rewrite.try_join_as_projection(
449+
result_node = bigframes.core.rewrite.try_row_join(
450450
self.node, other_node, conditions
451451
)
452452
if result_node is None:

bigframes/core/nodes.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,6 @@ def child_nodes(self) -> typing.Sequence[BigFrameNode]:
8383
"""Direct children of this node"""
8484
return tuple([])
8585

86-
@property
87-
def projection_base(self) -> BigFrameNode:
88-
return self
89-
9086
@property
9187
@abc.abstractmethod
9288
def row_count(self) -> typing.Optional[int]:
@@ -918,10 +914,6 @@ def row_count(self) -> Optional[int]:
918914
def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
919915
return (self.col_id,)
920916

921-
@property
922-
def projection_base(self) -> BigFrameNode:
923-
return self.child.projection_base
924-
925917
@property
926918
def added_fields(self) -> Tuple[Field, ...]:
927919
return (Field(self.col_id, bigframes.dtypes.INT_DTYPE),)
@@ -1095,10 +1087,6 @@ def variables_introduced(self) -> int:
10951087
def defines_namespace(self) -> bool:
10961088
return True
10971089

1098-
@property
1099-
def projection_base(self) -> BigFrameNode:
1100-
return self.child.projection_base
1101-
11021090
@property
11031091
def row_count(self) -> Optional[int]:
11041092
return self.child.row_count
@@ -1173,10 +1161,6 @@ def variables_introduced(self) -> int:
11731161
def row_count(self) -> Optional[int]:
11741162
return self.child.row_count
11751163

1176-
@property
1177-
def projection_base(self) -> BigFrameNode:
1178-
return self.child.projection_base
1179-
11801164
@property
11811165
def node_defined_ids(self) -> Tuple[bfet_ids.ColumnId, ...]:
11821166
return tuple(id for _, id in self.assignments)
@@ -1361,10 +1345,6 @@ def fields(self) -> Iterable[Field]:
13611345
def variables_introduced(self) -> int:
13621346
return 1
13631347

1364-
@property
1365-
def projection_base(self) -> BigFrameNode:
1366-
return self.child.projection_base
1367-
13681348
@property
13691349
def added_fields(self) -> Tuple[Field, ...]:
13701350
return (self.added_field,)

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313
# limitations under the License.
1414

1515
from bigframes.core.rewrite.identifiers import remap_variables
16-
from bigframes.core.rewrite.implicit_align import try_join_as_projection
16+
from bigframes.core.rewrite.implicit_align import try_row_join
1717
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
1818
from bigframes.core.rewrite.slices import pullup_limit_from_slice, replace_slice_ops
1919

2020
__all__ = [
2121
"legacy_join_as_projection",
22-
"try_join_as_projection",
22+
"try_row_join",
2323
"replace_slice_ops",
2424
"pullup_limit_from_slice",
2525
"remap_variables",

bigframes/core/rewrite/implicit_align.py

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import dataclasses
17-
from typing import Optional, Tuple
17+
from typing import Iterable, Optional, Tuple
1818

1919
import bigframes.core.expression
2020
import bigframes.core.guid
@@ -30,6 +30,11 @@
3030
bigframes.core.nodes.WindowOpNode,
3131
bigframes.core.nodes.PromoteOffsetsNode,
3232
)
33+
# Combination of selects and additive nodes can be merged as an explicit keyless "row join"
34+
ALIGNABLE_NODES = (
35+
*ADDITIVE_NODES,
36+
bigframes.core.nodes.SelectionNode,
37+
)
3338

3439

3540
@dataclasses.dataclass(frozen=True)
@@ -70,49 +75,29 @@ def get_expression_spec(
7075
bigframes.core.nodes.PromoteOffsetsNode,
7176
),
7277
):
73-
# we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
74-
# handles normalizing scalar expressions at the moment.
75-
pass
78+
if set(expression.column_references).isdisjoint(
79+
field.id for field in curr_node.added_fields
80+
):
81+
# we don't yet have a way of normalizing window ops into a ExpressionSpec, which only
82+
# handles normalizing scalar expressions at the moment.
83+
pass
84+
else:
85+
return ExpressionSpec(expression, curr_node)
7686
else:
7787
return ExpressionSpec(expression, curr_node)
7888
curr_node = curr_node.child
7989

8090

81-
def _linearize_trees(
82-
base_tree: bigframes.core.nodes.BigFrameNode,
83-
append_tree: bigframes.core.nodes.BigFrameNode,
84-
) -> bigframes.core.nodes.BigFrameNode:
85-
"""Linearize two divergent tree who only diverge through different additive nodes."""
86-
assert append_tree.projection_base == base_tree.projection_base
87-
# base case: append tree does not have any additive nodes to linearize
88-
if append_tree == append_tree.projection_base:
89-
return base_tree
90-
else:
91-
assert isinstance(append_tree, ADDITIVE_NODES)
92-
return append_tree.replace_child(_linearize_trees(base_tree, append_tree.child))
93-
94-
95-
def combine_nodes(
96-
l_node: bigframes.core.nodes.BigFrameNode,
97-
r_node: bigframes.core.nodes.BigFrameNode,
98-
) -> bigframes.core.nodes.BigFrameNode:
99-
assert l_node.projection_base == r_node.projection_base
100-
l_node, l_selection = pull_up_selection(l_node)
101-
r_node, r_selection = pull_up_selection(
102-
r_node, rename_vars=True
103-
) # Rename only right vars to avoid collisions with left vars
104-
combined_selection = (*l_selection, *r_selection)
105-
merged_node = _linearize_trees(l_node, r_node)
106-
return bigframes.core.nodes.SelectionNode(merged_node, combined_selection)
107-
108-
109-
def try_join_as_projection(
91+
def try_row_join(
11092
l_node: bigframes.core.nodes.BigFrameNode,
11193
r_node: bigframes.core.nodes.BigFrameNode,
11294
join_keys: Tuple[Tuple[str, str], ...],
11395
) -> Optional[bigframes.core.nodes.BigFrameNode]:
11496
"""Joins the two nodes"""
115-
if l_node.projection_base != r_node.projection_base:
97+
divergent_node = first_shared_descendent(
98+
l_node, r_node, descendable_types=ALIGNABLE_NODES
99+
)
100+
if divergent_node is None:
116101
return None
117102
# check join keys are equivalent by normalizing the expressions as much as posisble
118103
# instead of just comparing ids
@@ -124,11 +109,35 @@ def try_join_as_projection(
124109
r_node, right_id
125110
):
126111
return None
127-
return combine_nodes(l_node, r_node)
112+
113+
l_node, l_selection = pull_up_selection(l_node, stop=divergent_node)
114+
r_node, r_selection = pull_up_selection(
115+
r_node, stop=divergent_node, rename_vars=True
116+
) # Rename only right vars to avoid collisions with left vars
117+
combined_selection = (*l_selection, *r_selection)
118+
119+
def _linearize_trees(
120+
base_tree: bigframes.core.nodes.BigFrameNode,
121+
append_tree: bigframes.core.nodes.BigFrameNode,
122+
) -> bigframes.core.nodes.BigFrameNode:
123+
"""Linearize two divergent tree who only diverge through different additive nodes."""
124+
# base case: append tree does not have any divergent nodes to linearize
125+
if append_tree == divergent_node:
126+
return base_tree
127+
else:
128+
assert isinstance(append_tree, ADDITIVE_NODES)
129+
return append_tree.replace_child(
130+
_linearize_trees(base_tree, append_tree.child)
131+
)
132+
133+
merged_node = _linearize_trees(l_node, r_node)
134+
return bigframes.core.nodes.SelectionNode(merged_node, combined_selection)
128135

129136

130137
def pull_up_selection(
131-
node: bigframes.core.nodes.BigFrameNode, rename_vars: bool = False
138+
node: bigframes.core.nodes.BigFrameNode,
139+
stop: bigframes.core.nodes.BigFrameNode,
140+
rename_vars: bool = False,
132141
) -> Tuple[
133142
bigframes.core.nodes.BigFrameNode,
134143
Tuple[
@@ -147,14 +156,14 @@ def pull_up_selection(
147156
Returns:
148157
BigFrameNode, Selections
149158
"""
150-
if node == node.projection_base: # base case
159+
if node == stop: # base case
151160
return node, tuple(
152161
(bigframes.core.expression.DerefOp(field.id), field.id)
153162
for field in node.fields
154163
)
155164
assert isinstance(node, (bigframes.core.nodes.SelectionNode, *ADDITIVE_NODES))
156165
child_node, child_selections = pull_up_selection(
157-
node.child, rename_vars=rename_vars
166+
node.child, stop, rename_vars=rename_vars
158167
)
159168
mapping = {out: ref.id for ref, out in child_selections}
160169
if isinstance(node, ADDITIVE_NODES):
@@ -188,3 +197,30 @@ def pull_up_selection(
188197
)
189198
return child_node, new_selection
190199
raise ValueError(f"Couldn't pull up select from node: {node}")
200+
201+
202+
## Traversal helpers
203+
def first_shared_descendent(
204+
left: bigframes.core.nodes.BigFrameNode,
205+
right: bigframes.core.nodes.BigFrameNode,
206+
descendable_types: Tuple[type[bigframes.core.nodes.UnaryNode], ...],
207+
) -> Optional[bigframes.core.nodes.BigFrameNode]:
208+
l_path = tuple(descend(left, descendable_types))
209+
r_path = tuple(descend(right, descendable_types))
210+
if l_path[-1] != r_path[-1]:
211+
return None
212+
213+
for l_node, r_node in zip(l_path[-len(r_path) :], r_path[-len(l_path) :]):
214+
if l_node == r_node:
215+
return l_node
216+
# should be impossible, as l_path[-1] == r_path[-1]
217+
raise ValueError()
218+
219+
220+
def descend(
221+
root: bigframes.core.nodes.BigFrameNode,
222+
descendable_types: Tuple[type[bigframes.core.nodes.UnaryNode], ...],
223+
) -> Iterable[bigframes.core.nodes.BigFrameNode]:
224+
yield root
225+
if isinstance(root, descendable_types):
226+
yield from descend(root.child, descendable_types)

bigframes/core/rewrite/legacy_align.py

Lines changed: 11 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,17 @@
2323
import bigframes.core.join_def as join_defs
2424
import bigframes.core.nodes as nodes
2525
import bigframes.core.ordering as order
26+
import bigframes.core.rewrite.implicit_align
2627
import bigframes.operations as ops
2728

2829
Selection = Tuple[Tuple[scalar_exprs.Expression, ids.ColumnId], ...]
2930

30-
REWRITABLE_NODE_TYPES = (
31-
nodes.SelectionNode,
32-
nodes.ProjectionNode,
33-
nodes.FilterNode,
34-
nodes.ReversedNode,
35-
nodes.OrderByNode,
31+
LEGACY_REWRITER_NODES = (
32+
bigframes.core.nodes.ProjectionNode,
33+
bigframes.core.nodes.SelectionNode,
34+
bigframes.core.nodes.ReversedNode,
35+
bigframes.core.nodes.OrderByNode,
36+
bigframes.core.nodes.FilterNode,
3637
)
3738

3839

@@ -51,9 +52,7 @@ def from_node_span(
5152
cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode
5253
) -> SquashedSelect:
5354
if node == target:
54-
selection = tuple(
55-
(scalar_exprs.DerefOp(id), id) for id in get_node_column_ids(node)
56-
)
55+
selection = tuple((scalar_exprs.DerefOp(id), id) for id in node.ids)
5756
return cls(node, selection, None, ())
5857

5958
if isinstance(node, nodes.SelectionNode):
@@ -357,27 +356,10 @@ def decompose_conjunction(
357356
return (expr,)
358357

359358

360-
def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[ids.ColumnId, ...]:
361-
return tuple(field.id for field in node.fields)
362-
363-
364359
def common_selection_root(
365360
l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode
366361
) -> Optional[nodes.BigFrameNode]:
367362
"""Find common subtree between join subtrees"""
368-
l_node = l_tree
369-
l_nodes: set[nodes.BigFrameNode] = set()
370-
while isinstance(l_node, REWRITABLE_NODE_TYPES):
371-
l_nodes.add(l_node)
372-
l_node = l_node.child
373-
l_nodes.add(l_node)
374-
375-
r_node = r_tree
376-
while isinstance(r_node, REWRITABLE_NODE_TYPES):
377-
if r_node in l_nodes:
378-
return r_node
379-
r_node = r_node.child
380-
381-
if r_node in l_nodes:
382-
return r_node
383-
return None
363+
return bigframes.core.rewrite.implicit_align.first_shared_descendent(
364+
l_tree, r_tree, descendable_types=LEGACY_REWRITER_NODES
365+
)

0 commit comments

Comments
 (0)