Skip to content

Commit d73fe9d

Browse files
perf: Speed up tree transforms during sql compile (#1071)
1 parent 2dc22ec commit d73fe9d

File tree

4 files changed

+52
-28
lines changed

4 files changed

+52
-28
lines changed

bigframes/core/compile/compiler.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import bigframes.core.identifiers as ids
3737
import bigframes.core.nodes as nodes
3838
import bigframes.core.ordering as bf_ordering
39+
import bigframes.core.rewrite as rewrites
3940

4041
if typing.TYPE_CHECKING:
4142
import bigframes.core
@@ -48,20 +49,32 @@ class Compiler:
4849
# In unstrict mode, ordering from ReadTable or after joins may be ambiguous to improve query performance.
4950
strict: bool = True
5051
scalar_op_compiler = compile_scalar.ScalarOpCompiler()
52+
enable_pruning: bool = False
53+
54+
def _preprocess(self, node: nodes.BigFrameNode):
55+
if self.enable_pruning:
56+
used_fields = frozenset(field.id for field in node.fields)
57+
node = node.prune(used_fields)
58+
node = functools.cache(rewrites.replace_slice_ops)(node)
59+
return node
5160

5261
def compile_ordered_ir(self, node: nodes.BigFrameNode) -> compiled.OrderedIR:
53-
ir = typing.cast(compiled.OrderedIR, self.compile_node(node, True))
62+
ir = typing.cast(
63+
compiled.OrderedIR, self.compile_node(self._preprocess(node), True)
64+
)
5465
if self.strict:
5566
assert ir.has_total_order
5667
return ir
5768

5869
def compile_unordered_ir(self, node: nodes.BigFrameNode) -> compiled.UnorderedIR:
59-
return typing.cast(compiled.UnorderedIR, self.compile_node(node, False))
70+
return typing.cast(
71+
compiled.UnorderedIR, self.compile_node(self._preprocess(node), False)
72+
)
6073

6174
def compile_peak_sql(
6275
self, node: nodes.BigFrameNode, n_rows: int
6376
) -> typing.Optional[str]:
64-
return self.compile_unordered_ir(node).peek_sql(n_rows)
77+
return self.compile_unordered_ir(self._preprocess(node)).peek_sql(n_rows)
6578

6679
# TODO: Remove cache when schema no longer requires compilation to derive schema (and therefore only compiles for execution)
6780
@functools.lru_cache(maxsize=5000)

bigframes/core/nodes.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,11 @@ def explicitly_ordered(self) -> bool:
263263
def transform_children(
264264
self, t: Callable[[BigFrameNode], BigFrameNode]
265265
) -> BigFrameNode:
266-
return replace(self, child=t(self.child))
266+
transformed = replace(self, child=t(self.child))
267+
if self == transformed:
268+
# reusing existing object speeds up eq, and saves a small amount of memory
269+
return self
270+
return transformed
267271

268272
@property
269273
def order_ambiguous(self) -> bool:
@@ -350,9 +354,13 @@ def joins(self) -> bool:
350354
def transform_children(
351355
self, t: Callable[[BigFrameNode], BigFrameNode]
352356
) -> BigFrameNode:
353-
return replace(
357+
transformed = replace(
354358
self, left_child=t(self.left_child), right_child=t(self.right_child)
355359
)
360+
if self == transformed:
361+
# reusing existing object speeds up eq, and saves a small amount of memory
362+
return self
363+
return transformed
356364

357365
@property
358366
def defines_namespace(self) -> bool:
@@ -407,7 +415,11 @@ def variables_introduced(self) -> int:
407415
def transform_children(
408416
self, t: Callable[[BigFrameNode], BigFrameNode]
409417
) -> BigFrameNode:
410-
return replace(self, children=tuple(t(child) for child in self.children))
418+
transformed = replace(self, children=tuple(t(child) for child in self.children))
419+
if self == transformed:
420+
# reusing existing object speeds up eq, and saves a small amount of memory
421+
return self
422+
return transformed
411423

412424
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
413425
# TODO: Make concat prunable, probably by redefining
@@ -451,7 +463,11 @@ def variables_introduced(self) -> int:
451463
def transform_children(
452464
self, t: Callable[[BigFrameNode], BigFrameNode]
453465
) -> BigFrameNode:
454-
return replace(self, start=t(self.start), end=t(self.end))
466+
transformed = replace(self, start=t(self.start), end=t(self.end))
467+
if self == transformed:
468+
# reusing existing object speeds up eq, and saves a small amount of memory
469+
return self
470+
return transformed
455471

456472
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
457473
# TODO: Make FromRangeNode prunable (or convert to other node types)

bigframes/core/tree_properties.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def _node_counts_inner(
113113

114114
node_counts = _node_counts_inner(root)
115115

116+
if len(node_counts) == 0:
117+
raise ValueError("node counts should be non-zero")
118+
116119
return max(
117120
node_counts.keys(),
118121
key=lambda node: heuristic(

bigframes/session/executor.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@
4545
import bigframes.core.identifiers
4646
import bigframes.core.nodes as nodes
4747
import bigframes.core.ordering as order
48-
import bigframes.core.rewrite as rewrites
4948
import bigframes.core.schema
5049
import bigframes.core.tree_properties as tree_properties
5150
import bigframes.features
@@ -128,7 +127,7 @@ def to_sql(
128127
col_id_overrides = dict(col_id_overrides)
129128
col_id_overrides[internal_offset_col] = offset_column
130129
node = (
131-
self._get_optimized_plan(array_value.node)
130+
self._sub_cache_subtrees(array_value.node)
132131
if enable_cache
133132
else array_value.node
134133
)
@@ -279,7 +278,7 @@ def peek(
279278
"""
280279
A 'peek' efficiently accesses a small number of rows in the dataframe.
281280
"""
282-
plan = self._get_optimized_plan(array_value.node)
281+
plan = self._sub_cache_subtrees(array_value.node)
283282
if not tree_properties.can_fast_peek(plan):
284283
warnings.warn("Peeking this value cannot be done efficiently.")
285284

@@ -314,15 +313,15 @@ def head(
314313
# No user-provided ordering, so just get any N rows, its faster!
315314
return self.peek(array_value, n_rows)
316315

317-
plan = self._get_optimized_plan(array_value.node)
316+
plan = self._sub_cache_subtrees(array_value.node)
318317
if not tree_properties.can_fast_head(plan):
319318
# If can't get head fast, we are going to need to execute the whole query
320319
# Will want to do this in a way such that the result is reusable, but the first
321320
# N values can be easily extracted.
322321
# This currently requires clustering on offsets.
323322
self._cache_with_offsets(array_value)
324323
# Get a new optimized plan after caching
325-
plan = self._get_optimized_plan(array_value.node)
324+
plan = self._sub_cache_subtrees(array_value.node)
326325
assert tree_properties.can_fast_head(plan)
327326

328327
head_plan = generate_head_plan(plan, n_rows)
@@ -347,7 +346,7 @@ def get_row_count(self, array_value: bigframes.core.ArrayValue) -> int:
347346
if count is not None:
348347
return count
349348
else:
350-
row_count_plan = self._get_optimized_plan(
349+
row_count_plan = self._sub_cache_subtrees(
351350
generate_row_count_plan(array_value.node)
352351
)
353352
sql = self.compiler.compile_unordered(row_count_plan)
@@ -359,7 +358,7 @@ def _local_get_row_count(
359358
) -> Optional[int]:
360359
# optimized plan has cache materializations which will have row count metadata
361360
# that is more likely to be usable than original leaf nodes.
362-
plan = self._get_optimized_plan(array_value.node)
361+
plan = self._sub_cache_subtrees(array_value.node)
363362
return tree_properties.row_count(plan)
364363

365364
# Helpers
@@ -424,21 +423,14 @@ def _wait_on_job(
424423
self.metrics.count_job_stats(query_job)
425424
return results_iterator
426425

427-
def _get_optimized_plan(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
426+
def _sub_cache_subtrees(self, node: nodes.BigFrameNode) -> nodes.BigFrameNode:
428427
"""
429428
Takes the original expression tree and applies optimizations to accelerate execution.
430429
431430
At present, the only optimization is to replace subtrees with cached previous materializations.
432431
"""
433432
# Apply any rewrites *after* applying cache, as cache is sensitive to exact tree structure
434-
optimized_plan = tree_properties.replace_nodes(
435-
node, (dict(self._cached_executions))
436-
)
437-
if ENABLE_PRUNING:
438-
used_fields = frozenset(field.id for field in optimized_plan.fields)
439-
optimized_plan = optimized_plan.prune(used_fields)
440-
optimized_plan = rewrites.replace_slice_ops(optimized_plan)
441-
return optimized_plan
433+
return tree_properties.replace_nodes(node, (dict(self._cached_executions)))
442434

443435
def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
444436
"""
@@ -448,7 +440,7 @@ def _is_trivially_executable(self, array_value: bigframes.core.ArrayValue):
448440
# Once rewriting is available, will want to rewrite before
449441
# evaluating execution cost.
450442
return tree_properties.is_trivially_executable(
451-
self._get_optimized_plan(array_value.node)
443+
self._sub_cache_subtrees(array_value.node)
452444
)
453445

454446
def _cache_with_cluster_cols(
@@ -457,7 +449,7 @@ def _cache_with_cluster_cols(
457449
"""Executes the query and uses the resulting table to rewrite future executions."""
458450

459451
sql, schema, ordering_info = self.compiler.compile_raw(
460-
self._get_optimized_plan(array_value.node)
452+
self._sub_cache_subtrees(array_value.node)
461453
)
462454
tmp_table = self._sql_as_cached_temp_table(
463455
sql,
@@ -474,7 +466,7 @@ def _cache_with_offsets(self, array_value: bigframes.core.ArrayValue):
474466
"""Executes the query and uses the resulting table to rewrite future executions."""
475467
offset_column = bigframes.core.guid.generate_guid("bigframes_offsets")
476468
w_offsets, offset_column = array_value.promote_offsets()
477-
sql = self.compiler.compile_unordered(self._get_optimized_plan(w_offsets.node))
469+
sql = self.compiler.compile_unordered(self._sub_cache_subtrees(w_offsets.node))
478470

479471
tmp_table = self._sql_as_cached_temp_table(
480472
sql,
@@ -510,7 +502,7 @@ def _simplify_with_caching(self, array_value: bigframes.core.ArrayValue):
510502
"""Attempts to handle the complexity by caching duplicated subtrees and breaking the query into pieces."""
511503
# Apply existing caching first
512504
for _ in range(MAX_SUBTREE_FACTORINGS):
513-
node_with_cache = self._get_optimized_plan(array_value.node)
505+
node_with_cache = self._sub_cache_subtrees(array_value.node)
514506
if node_with_cache.planning_complexity < QUERY_COMPLEXITY_LIMIT:
515507
return
516508

@@ -567,7 +559,7 @@ def _validate_result_schema(
567559
):
568560
actual_schema = tuple(bq_schema)
569561
ibis_schema = bigframes.core.compile.test_only_ibis_inferred_schema(
570-
self._get_optimized_plan(array_value.node)
562+
self._sub_cache_subtrees(array_value.node)
571563
)
572564
internal_schema = array_value.schema
573565
if not bigframes.features.PANDAS_VERSIONS.is_arrow_list_dtype_usable:

0 commit comments

Comments
 (0)