Skip to content

Commit de0881b

Browse files
refactor: caching is now a session property (#697)
1 parent f89b6be commit de0881b

File tree

7 files changed

+140
-96
lines changed

7 files changed

+140
-96
lines changed

bigframes/core/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def project_to_id(self, expression: ex.Expression, output_id: str):
184184
child=self.node,
185185
assignments=tuple(exprs),
186186
)
187-
).merge_projections()
187+
)
188188

189189
def assign(self, source_id: str, destination_id: str) -> ArrayValue:
190190
if destination_id in self.column_ids: # Mutate case
@@ -209,7 +209,7 @@ def assign(self, source_id: str, destination_id: str) -> ArrayValue:
209209
child=self.node,
210210
assignments=tuple(exprs),
211211
)
212-
).merge_projections()
212+
)
213213

214214
def assign_constant(
215215
self,
@@ -243,7 +243,7 @@ def assign_constant(
243243
child=self.node,
244244
assignments=tuple(exprs),
245245
)
246-
).merge_projections()
246+
)
247247

248248
def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
249249
selections = ((ex.free_var(col_id), col_id) for col_id in column_ids)
@@ -252,7 +252,7 @@ def select_columns(self, column_ids: typing.Sequence[str]) -> ArrayValue:
252252
child=self.node,
253253
assignments=tuple(selections),
254254
)
255-
).merge_projections()
255+
)
256256

257257
def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
258258
new_projection = (
@@ -265,7 +265,7 @@ def drop_columns(self, columns: Iterable[str]) -> ArrayValue:
265265
child=self.node,
266266
assignments=tuple(new_projection),
267267
)
268-
).merge_projections()
268+
)
269269

270270
def aggregate(
271271
self,

bigframes/core/blocks.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2026,27 +2026,17 @@ def to_sql_query(
20262026
idx_labels,
20272027
)
20282028

2029-
def cached(self, *, optimize_offsets=False, force: bool = False) -> Block:
2030-
"""Write the block to a session table and create a new block object that references it."""
2029+
def cached(self, *, optimize_offsets=False, force: bool = False) -> None:
2030+
"""Write the block to a session table."""
20312031
# use a heuristic for whether something needs to be cached
20322032
if (not force) and self.session._is_trivially_executable(self.expr):
2033-
return self
2033+
return
20342034
if optimize_offsets:
2035-
expr = self.session._cache_with_offsets(self.expr)
2035+
self.session._cache_with_offsets(self.expr)
20362036
else:
2037-
expr = self.session._cache_with_cluster_cols(
2037+
self.session._cache_with_cluster_cols(
20382038
self.expr, cluster_cols=self.index_columns
20392039
)
2040-
return self.swap_array_expr(expr)
2041-
2042-
def swap_array_expr(self, expr: core.ArrayValue) -> Block:
2043-
# TODO: Validate schema unchanged
2044-
return Block(
2045-
expr,
2046-
index_columns=self.index_columns,
2047-
column_labels=self.column_labels,
2048-
index_labels=self.index.names,
2049-
)
20502040

20512041
def _is_monotonic(
20522042
self, column_ids: typing.Union[str, Sequence[str]], increasing: bool
@@ -2116,8 +2106,8 @@ def _get_rows_as_json_values(self) -> Block:
21162106
# TODO(shobs): Replace direct SQL manipulation by structured expression
21172107
# manipulation
21182108
ordering_column_name = guid.generate_guid()
2119-
expr = self.session._cache_with_offsets(self.expr)
2120-
expr = expr.promote_offsets(ordering_column_name)
2109+
self.session._cache_with_offsets(self.expr)
2110+
expr = self.expr.promote_offsets(ordering_column_name)
21212111
expr_sql = self.session._to_sql(expr)
21222112

21232113
# Names of the columns to serialize for the row.

bigframes/core/indexes/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def __new__(
9595
result = typing.cast(Index, object.__new__(klass)) # type: ignore
9696
result._query_job = None
9797
result._block = block
98+
block.session._register_object(result)
9899
return result
99100

100101
@classmethod

bigframes/core/tree_properties.py

Lines changed: 43 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import functools
1717
import itertools
18-
from typing import Dict
18+
from typing import Callable, Dict, Optional
1919

2020
import bigframes.core.nodes as nodes
2121

@@ -40,46 +40,66 @@ def peekable(node: nodes.BigFrameNode) -> bool:
4040
return children_peekable and self_peekable
4141

4242

43-
def count_complex_nodes(
44-
root: nodes.BigFrameNode, min_complexity: float, max_complexity: float
45-
) -> Dict[nodes.BigFrameNode, int]:
43+
# Replace modified_cost(node) = cost(apply_cache(node))
44+
def select_cache_target(
45+
root: nodes.BigFrameNode,
46+
min_complexity: float,
47+
max_complexity: float,
48+
cache: dict[nodes.BigFrameNode, nodes.BigFrameNode],
49+
heuristic: Callable[[int, int], float],
50+
) -> Optional[nodes.BigFrameNode]:
51+
"""Take tree, and return candidate nodes with (# of occurences, post-caching planning complexity).
52+
53+
heuristic takes two args, node complexity, and node occurrence count, in that order
54+
"""
55+
56+
@functools.cache
57+
def _with_caching(subtree: nodes.BigFrameNode) -> nodes.BigFrameNode:
58+
return replace_nodes(subtree, cache)
59+
60+
def _combine_counts(
61+
left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
62+
) -> Dict[nodes.BigFrameNode, int]:
63+
return {
64+
key: left.get(key, 0) + right.get(key, 0)
65+
for key in itertools.chain(left.keys(), right.keys())
66+
}
67+
4668
@functools.cache
4769
def _node_counts_inner(
4870
subtree: nodes.BigFrameNode,
4971
) -> Dict[nodes.BigFrameNode, int]:
5072
"""Helper function to count occurences of duplicate nodes in a subtree. Considers only nodes in a complexity range"""
5173
empty_counts: Dict[nodes.BigFrameNode, int] = {}
52-
if subtree.planning_complexity >= min_complexity:
74+
subtree_complexity = _with_caching(subtree).planning_complexity
75+
if subtree_complexity >= min_complexity:
5376
child_counts = [_node_counts_inner(child) for child in subtree.child_nodes]
5477
node_counts = functools.reduce(_combine_counts, child_counts, empty_counts)
55-
if subtree.planning_complexity <= max_complexity:
78+
if subtree_complexity <= max_complexity:
5679
return _combine_counts(node_counts, {subtree: 1})
5780
else:
5881
return node_counts
5982
return empty_counts
6083

61-
return _node_counts_inner(root)
84+
node_counts = _node_counts_inner(root)
85+
86+
return max(
87+
node_counts.keys(),
88+
key=lambda node: heuristic(
89+
_with_caching(node).planning_complexity, node_counts[node]
90+
),
91+
)
6292

6393

6494
def replace_nodes(
6595
root: nodes.BigFrameNode,
66-
to_replace: nodes.BigFrameNode,
67-
replacemenet: nodes.BigFrameNode,
96+
replacements: dict[nodes.BigFrameNode, nodes.BigFrameNode],
6897
):
6998
@functools.cache
70-
def apply_substition(n: nodes.BigFrameNode) -> nodes.BigFrameNode:
71-
if n == to_replace:
72-
return replacemenet
99+
def apply_substition(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
100+
if node in replacements.keys():
101+
return replacements[node]
73102
else:
74-
return n.transform_children(apply_substition)
75-
76-
return root.transform_children(apply_substition)
77-
103+
return node.transform_children(apply_substition)
78104

79-
def _combine_counts(
80-
left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
81-
) -> Dict[nodes.BigFrameNode, int]:
82-
return {
83-
key: left.get(key, 0) + right.get(key, 0)
84-
for key in itertools.chain(left.keys(), right.keys())
85-
}
105+
return apply_substition(root)

bigframes/dataframe.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def __init__(
191191
else:
192192
self._block = bigframes.pandas.read_pandas(pd_dataframe)._get_block()
193193
self._query_job: Optional[bigquery.QueryJob] = None
194+
self._block.session._register_object(self)
194195

195196
def __dir__(self):
196197
return dir(type(self)) + [
@@ -3515,16 +3516,15 @@ def _cached(self, *, force: bool = False) -> DataFrame:
35153516
No-op if the dataframe represents a trivial transformation of an existing materialization.
35163517
Force=True is used for BQML integration where need to copy data rather than use snapshot.
35173518
"""
3518-
self._set_block(self._block.cached(force=force))
3519+
self._block.cached(force=force)
35193520
return self
35203521

35213522
def _optimize_query_complexity(self):
35223523
"""Reduce query complexity by caching repeated subtrees and recursively materializing maximum-complexity subtrees.
35233524
May generate many queries and take substantial time to execute.
35243525
"""
35253526
# TODO: Move all this to session
3526-
new_expr = self._session._simplify_with_caching(self._block.expr)
3527-
self._set_block(self._block.swap_array_expr(new_expr))
3527+
self._session._simplify_with_caching(self._block.expr)
35283528

35293529
_DataFrameOrSeries = typing.TypeVar("_DataFrameOrSeries")
35303530

bigframes/series.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class Series(bigframes.operations.base.SeriesMethods, vendored_pandas_series.Ser
7272
def __init__(self, *args, **kwargs):
7373
self._query_job: Optional[bigquery.QueryJob] = None
7474
super().__init__(*args, **kwargs)
75+
self._block.session._register_object(self)
7576

7677
@property
7778
def dt(self) -> dt.DatetimeMethods:
@@ -1777,16 +1778,15 @@ def cache(self):
17771778
return self._cached(force=True)
17781779

17791780
def _cached(self, *, force: bool = True) -> Series:
1780-
self._set_block(self._block.cached(force=force))
1781+
self._block.cached(force=force)
17811782
return self
17821783

17831784
def _optimize_query_complexity(self):
17841785
"""Reduce query complexity by caching repeated subtrees and recursively materializing maximum-complexity subtrees.
17851786
May generate many queries and take substantial time to execute.
17861787
"""
17871788
# TODO: Move all this to session
1788-
new_expr = self._block.session._simplify_with_caching(self._block.expr)
1789-
self._set_block(self._block.swap_array_expr(new_expr))
1789+
self._block.session._simplify_with_caching(self._block.expr)
17901790

17911791

17921792
def _is_list_like(obj: typing.Any) -> typing_extensions.TypeGuard[typing.Sequence]:

0 commit comments

Comments
 (0)