Skip to content

Commit 1b96b80

Browse files
fix: Self-join optimization doesn't needlessly invalidate caching (#797)
1 parent 2ebb618 commit 1b96b80

File tree

3 files changed

+91
-41
lines changed

3 files changed

+91
-41
lines changed

bigframes/core/__init__.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -505,11 +505,11 @@ def try_align_as_projection(
505505
join_type: join_def.JoinType,
506506
mappings: typing.Tuple[join_def.JoinColumnMapping, ...],
507507
) -> typing.Optional[ArrayValue]:
508-
left_side = bigframes.core.rewrite.SquashedSelect.from_node(self.node)
509-
right_side = bigframes.core.rewrite.SquashedSelect.from_node(other.node)
510-
result = left_side.maybe_merge(right_side, join_type, mappings)
508+
result = bigframes.core.rewrite.join_as_projection(
509+
self.node, other.node, mappings, join_type
510+
)
511511
if result is not None:
512-
return ArrayValue(result.expand())
512+
return ArrayValue(result)
513513
return None
514514

515515
def explode(self, column_ids: typing.Sequence[str]) -> ArrayValue:
@@ -528,7 +528,3 @@ def _uniform_sampling(self, fraction: float) -> ArrayValue:
528528
The row numbers of result is non-deterministic, avoid to use.
529529
"""
530530
return ArrayValue(nodes.RandomSampleNode(self.node, fraction))
531-
532-
def merge_projections(self) -> ArrayValue:
533-
new_node = bigframes.core.rewrite.maybe_squash_projection(self.node)
534-
return ArrayValue(new_node)

bigframes/core/rewrite.py

Lines changed: 75 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -26,10 +26,17 @@
2626

2727
Selection = Tuple[Tuple[scalar_exprs.Expression, str], ...]
2828

29+
REWRITABLE_NODE_TYPES = (
30+
nodes.ProjectionNode,
31+
nodes.FilterNode,
32+
nodes.ReversedNode,
33+
nodes.OrderByNode,
34+
)
35+
2936

3037
@dataclasses.dataclass(frozen=True)
3138
class SquashedSelect:
32-
"""Squash together as many nodes as possible, separating out the projection, filter and reordering expressions."""
39+
"""Squash nodes together until target node, separating out the projection, filter and reordering expressions."""
3340

3441
root: nodes.BigFrameNode
3542
columns: Tuple[Tuple[scalar_exprs.Expression, str], ...]
@@ -38,25 +45,25 @@ class SquashedSelect:
3845
reverse_root: bool = False
3946

4047
@classmethod
41-
def from_node(
42-
cls, node: nodes.BigFrameNode, projections_only: bool = False
48+
def from_node_span(
49+
cls, node: nodes.BigFrameNode, target: nodes.BigFrameNode
4350
) -> SquashedSelect:
44-
if isinstance(node, nodes.ProjectionNode):
45-
return cls.from_node(node.child, projections_only=projections_only).project(
46-
node.assignments
47-
)
48-
elif not projections_only and isinstance(node, nodes.FilterNode):
49-
return cls.from_node(node.child).filter(node.predicate)
50-
elif not projections_only and isinstance(node, nodes.ReversedNode):
51-
return cls.from_node(node.child).reverse()
52-
elif not projections_only and isinstance(node, nodes.OrderByNode):
53-
return cls.from_node(node.child).order_with(node.by)
54-
else:
51+
if node == target:
5552
selection = tuple(
5653
(scalar_exprs.UnboundVariableExpression(id), id)
5754
for id in get_node_column_ids(node)
5855
)
5956
return cls(node, selection, None, ())
57+
if isinstance(node, nodes.ProjectionNode):
58+
return cls.from_node_span(node.child, target).project(node.assignments)
59+
elif isinstance(node, nodes.FilterNode):
60+
return cls.from_node_span(node.child, target).filter(node.predicate)
61+
elif isinstance(node, nodes.ReversedNode):
62+
return cls.from_node_span(node.child, target).reverse()
63+
elif isinstance(node, nodes.OrderByNode):
64+
return cls.from_node_span(node.child, target).order_with(node.by)
65+
else:
66+
raise ValueError(f"Cannot rewrite node {node}")
6067

6168
@property
6269
def column_lookup(self) -> Mapping[str, scalar_exprs.Expression]:
@@ -98,9 +105,10 @@ def order_with(self, by: Tuple[order.OrderingExpression, ...]):
98105
self.root, self.columns, self.predicate, new_ordering, self.reverse_root
99106
)
100107

101-
def can_join(
108+
def can_merge(
102109
self, right: SquashedSelect, join_def: join_defs.JoinDefinition
103110
) -> bool:
111+
"""Determines whether the two selections can be merged into a single selection."""
104112
if join_def.type == "cross":
105113
# Cannot convert cross join to projection
106114
return False
@@ -116,14 +124,14 @@ def can_join(
116124
return False
117125
return True
118126

119-
def maybe_merge(
127+
def merge(
120128
self,
121129
right: SquashedSelect,
122130
join_type: join_defs.JoinType,
123131
mappings: Tuple[join_defs.JoinColumnMapping, ...],
124-
) -> Optional[SquashedSelect]:
132+
) -> SquashedSelect:
125133
if self.root != right.root:
126-
return None
134+
raise ValueError("Cannot merge expressions with different roots")
127135
# Mask columns and remap names to expected schema
128136
lselection = self.columns
129137
rselection = right.columns
@@ -196,28 +204,40 @@ def expand(self) -> nodes.BigFrameNode:
196204
return nodes.ProjectionNode(child=root, assignments=self.columns)
197205

198206

199-
def maybe_squash_projection(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
200-
if isinstance(node, nodes.ProjectionNode) and isinstance(
201-
node.child, nodes.ProjectionNode
202-
):
203-
# Conservative approach, only squash consecutive projections, even though could also squash filters, reorderings
204-
return SquashedSelect.from_node(node, projections_only=True).expand()
205-
return node
206-
207-
208207
def maybe_rewrite_join(join_node: nodes.JoinNode) -> nodes.BigFrameNode:
209-
left_side = SquashedSelect.from_node(join_node.left_child)
210-
right_side = SquashedSelect.from_node(join_node.right_child)
211-
if left_side.can_join(right_side, join_node.join):
212-
merged = left_side.maybe_merge(
208+
rewrite_common_node = common_selection_root(
209+
join_node.left_child, join_node.right_child
210+
)
211+
if rewrite_common_node is None:
212+
return join_node
213+
left_side = SquashedSelect.from_node_span(join_node.left_child, rewrite_common_node)
214+
right_side = SquashedSelect.from_node_span(
215+
join_node.right_child, rewrite_common_node
216+
)
217+
if left_side.can_merge(right_side, join_node.join):
218+
return left_side.merge(
213219
right_side, join_node.join.type, join_node.join.mappings
214-
)
220+
).expand()
221+
return join_node
222+
223+
224+
def join_as_projection(
225+
l_node: nodes.BigFrameNode,
226+
r_node: nodes.BigFrameNode,
227+
mappings: Tuple[join_defs.JoinColumnMapping, ...],
228+
how: join_defs.JoinType,
229+
) -> Optional[nodes.BigFrameNode]:
230+
rewrite_common_node = common_selection_root(l_node, r_node)
231+
if rewrite_common_node is not None:
232+
left_side = SquashedSelect.from_node_span(l_node, rewrite_common_node)
233+
right_side = SquashedSelect.from_node_span(r_node, rewrite_common_node)
234+
merged = left_side.merge(right_side, how, mappings)
215235
assert (
216236
merged is not None
217237
), "Couldn't merge nodes. This shouldn't happen. Please share full stacktrace with the BigQuery DataFrames team at bigframes-feedback@google.com."
218238
return merged.expand()
219239
else:
220-
return join_node
240+
return None
221241

222242

223243
def remap_names(
@@ -311,3 +331,25 @@ def get_node_column_ids(node: nodes.BigFrameNode) -> Tuple[str, ...]:
311331
import bigframes.core
312332

313333
return tuple(bigframes.core.ArrayValue(node).column_ids)
334+
335+
336+
def common_selection_root(
337+
l_tree: nodes.BigFrameNode, r_tree: nodes.BigFrameNode
338+
) -> Optional[nodes.BigFrameNode]:
339+
"""Find common subtree between join subtrees"""
340+
l_node = l_tree
341+
l_nodes: set[nodes.BigFrameNode] = set()
342+
while isinstance(l_node, REWRITABLE_NODE_TYPES):
343+
l_nodes.add(l_node)
344+
l_node = l_node.child
345+
l_nodes.add(l_node)
346+
347+
r_node = r_tree
348+
while isinstance(r_node, REWRITABLE_NODE_TYPES):
349+
if r_node in l_nodes:
350+
return r_node
351+
r_node = r_node.child
352+
353+
if r_node in l_nodes:
354+
return r_node
355+
return None

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4328,6 +4328,18 @@ def test_df_cached(scalars_df_index):
43284328
pandas.testing.assert_frame_equal(df.to_pandas(), df_cached_copy.to_pandas())
43294329

43304330

4331+
def test_df_cache_with_implicit_join(scalars_df_index):
4332+
"""expectation is that cache will be used, but no explicit join will be performed"""
4333+
df = scalars_df_index[["int64_col", "int64_too"]].sort_index().reset_index() + 3
4334+
df.cache()
4335+
bf_result = df + (df * 2)
4336+
sql = bf_result.sql
4337+
4338+
# Very crude asserts, want sql to not use join and not use base table, only reference cached table
4339+
assert "JOIN" not in sql
4340+
assert "bigframes_testing" not in sql
4341+
4342+
43314343
def test_df_dot_inline(session):
43324344
df1 = pd.DataFrame([[1, 2, 3], [2, 5, 7]])
43334345
df2 = pd.DataFrame([[2, 4, 8], [1, 5, 10], [3, 6, 9]])

0 commit comments

Comments
 (0)