Skip to content

Commit e698dbf

Browse files
refactor: Create memoized tree traversal helpers (#1198)
1 parent 71b4053 commit e698dbf

File tree

7 files changed

+65
-31
lines changed

7 files changed

+65
-31
lines changed

bigframes/core/compile/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _preprocess(self, node: nodes.BigFrameNode):
8282
if self.enable_pruning:
8383
used_fields = frozenset(field.id for field in node.fields)
8484
node = node.prune(used_fields)
85-
node = functools.cache(rewrites.replace_slice_ops)(node)
85+
node = nodes.bottom_up(node, rewrites.rewrite_slice)
8686
if self.enable_densify_ids:
8787
original_names = [id.name for id in node.ids]
8888
node, _ = rewrites.remap_variables(

bigframes/core/compile/polars/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def compile(self, array_value: bigframes.core.ArrayValue) -> pl.LazyFrame:
173173

174174
# TODO: Create standard way to configure BFET -> BFET rewrites
175175
# Polars has incomplete slice support in lazy mode
176-
node = bigframes.core.rewrite.replace_slice_ops(array_value.node)
176+
node = nodes.bottom_up(array_value.node, bigframes.core.rewrite.rewrite_slice)
177177
return self.compile_node(node)
178178

179179
@functools.singledispatchmethod

bigframes/core/nodes.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,3 +1486,56 @@ def remap_refs(
14861486
) -> BigFrameNode:
14871487
new_ids = tuple(id.remap_column_refs(mappings) for id in self.column_ids)
14881488
return dataclasses.replace(self, column_ids=new_ids) # type: ignore
1489+
1490+
1491+
# Tree operators
1492+
def top_down(
    root: BigFrameNode,
    transform: Callable[[BigFrameNode], BigFrameNode],
    *,
    memoize=False,
    validate=False,
) -> BigFrameNode:
    """
    Rewrite a BigFrameNode tree starting at the root and working downwards.

    For every visited node, ``transform`` is applied first; the traversal then
    continues through ``transform_children`` of the node that ``transform``
    returned.

    Args:
        root: Node at which the traversal starts.
        transform: Callable applied to each visited node.
        memoize: If True, the recursive step is wrapped in ``functools.cache``.
            The wrapper is created inside this call, so cached results are
            only reused within this single traversal.
        validate: If True, ``validate_tree()`` is called on the final result
            before it is returned.

    Returns:
        The node produced by the traversal for ``root``.
    """

    def _visit(node: BigFrameNode) -> BigFrameNode:
        rewritten = transform(node)
        # Recurse through `walk`, not `_visit`: the name is looked up at call
        # time, so with memoize=True every recursive step hits the cache.
        return rewritten.transform_children(walk)

    walk = functools.cache(_visit) if memoize else _visit

    new_root = walk(root)
    if validate:
        new_root.validate_tree()
    return new_root
1516+
1517+
1518+
def bottom_up(
    root: BigFrameNode,
    transform: Callable[[BigFrameNode], BigFrameNode],
    *,
    memoize=False,
    validate=False,
) -> BigFrameNode:
    """
    Rewrite a BigFrameNode tree starting at the leaves and working upwards.

    For every visited node, the traversal first recurses through
    ``transform_children``; ``transform`` is then applied to the result of
    that call.

    Args:
        root: Node at which the traversal starts.
        transform: Callable applied to each node after its children have been
            processed.
        memoize: If True, the recursive step is wrapped in ``functools.cache``.
            The wrapper is created inside this call, so cached results are
            only reused within this single traversal.
        validate: If True, ``validate_tree()`` is called on the final result
            before it is returned.

    Returns:
        The node produced by the traversal for ``root``.
    """

    def _visit(node: BigFrameNode) -> BigFrameNode:
        # Recurse through `walk`, not `_visit`: the name is looked up at call
        # time, so with memoize=True every recursive step hits the cache.
        with_new_children = node.transform_children(walk)
        return transform(with_new_children)

    walk = functools.cache(_visit) if memoize else _visit

    new_root = walk(root)
    if validate:
        new_root.validate_tree()
    return new_root

bigframes/core/rewrite/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
from bigframes.core.rewrite.identifiers import remap_variables
1616
from bigframes.core.rewrite.implicit_align import try_row_join
1717
from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
18-
from bigframes.core.rewrite.slices import pullup_limit_from_slice, replace_slice_ops
18+
from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice
1919

2020
__all__ = [
2121
"legacy_join_as_projection",
2222
"try_row_join",
23-
"replace_slice_ops",
23+
"rewrite_slice",
2424
"pullup_limit_from_slice",
2525
"remap_variables",
2626
]

bigframes/core/rewrite/slices.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from __future__ import annotations
1515

1616
import functools
17-
from typing import cast, Optional, Sequence, Tuple
17+
from typing import Optional, Sequence, Tuple
1818

1919
import bigframes.core.expression as scalar_exprs
2020
import bigframes.core.guid as guids
@@ -55,18 +55,11 @@ def pullup_limit_from_slice(
5555
return root, None
5656

5757

58-
def replace_slice_ops(root: nodes.BigFrameNode) -> nodes.BigFrameNode:
59-
# TODO: we want to pull up some slices into limit op if near root.
60-
if isinstance(root, nodes.SliceNode):
61-
root = root.transform_children(replace_slice_ops)
62-
return rewrite_slice(cast(nodes.SliceNode, root))
63-
else:
64-
return root.transform_children(replace_slice_ops)
58+
def rewrite_slice(node: nodes.BigFrameNode):
59+
if not isinstance(node, nodes.SliceNode):
60+
return node
6561

66-
67-
def rewrite_slice(node: nodes.SliceNode):
6862
slice_def = (node.start, node.stop, node.step)
69-
7063
# no-op (eg. df[::1])
7164
if slices.is_noop(slice_def, node.child.row_count):
7265
return node.child

bigframes/core/tree_properties.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def select_cache_target(
8888

8989
@functools.cache
9090
def _with_caching(subtree: nodes.BigFrameNode) -> nodes.BigFrameNode:
91-
return replace_nodes(subtree, cache)
91+
return nodes.top_down(subtree, lambda x: cache.get(x, x), memoize=True)
9292

9393
def _combine_counts(
9494
left: Dict[nodes.BigFrameNode, int], right: Dict[nodes.BigFrameNode, int]
@@ -162,17 +162,3 @@ def _node_counts_inner(
162162

163163
counts = [_node_counts_inner(root) for root in forest]
164164
return functools.reduce(_combine_counts, counts, empty_counts)
165-
166-
167-
def replace_nodes(
168-
root: nodes.BigFrameNode,
169-
replacements: dict[nodes.BigFrameNode, nodes.BigFrameNode],
170-
):
171-
@functools.cache
172-
def apply_substition(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
173-
if node in replacements.keys():
174-
return replacements[node]
175-
else:
176-
return node.transform_children(apply_substition)
177-
178-
return apply_substition(root)

bigframes/session/executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -529,7 +529,9 @@ def _wait_on_job(
529529
return results_iterator
530530

531531
def replace_cached_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
532-
return tree_properties.replace_nodes(node, (dict(self._cached_executions)))
532+
return nodes.top_down(
533+
node, lambda x: self._cached_executions.get(x, x), memoize=True
534+
)
533535

534536
def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
535537
"""

0 commit comments

Comments
 (0)