Commit 3c314c3

refactor: Unify compile paths with ResultNode (#1636)
1 parent: f68b80c

10 files changed: +224 additions, -116 deletions

bigframes/core/compile/api.py

Lines changed: 9 additions & 5 deletions
@@ -23,7 +23,6 @@
 if TYPE_CHECKING:
     import bigframes.core.nodes
     import bigframes.core.ordering
-    import bigframes.core.schema


 class SQLCompiler:
@@ -35,8 +34,8 @@ def compile(
         limit: Optional[int] = None,
     ) -> str:
         """Compile node into sql where rows are sorted with ORDER BY."""
-        # If we are ordering the query anyways, compiling the slice as a limit is probably a good idea.
-        return compiler.compile_sql(node, ordered=ordered, limit=limit)
+        request = compiler.CompileRequest(node, sort_rows=ordered, peek_count=limit)
+        return compiler.compile_sql(request).sql

     def compile_raw(
         self,
@@ -45,15 +44,20 @@ def compile_raw(
         str, Sequence[bigquery.SchemaField], bigframes.core.ordering.RowOrdering
     ]:
         """Compile node into sql that exposes all columns, including hidden ordering-only columns."""
-        return compiler.compile_raw(node)
+        request = compiler.CompileRequest(
+            node, sort_rows=False, materialize_all_order_keys=True
+        )
+        result = compiler.compile_sql(request)
+        assert result.row_order is not None
+        return result.sql, result.sql_schema, result.row_order


 def test_only_ibis_inferred_schema(node: bigframes.core.nodes.BigFrameNode):
     """Use only for testing paths to ensure ibis inferred schema does not diverge from bigframes inferred schema."""
     import bigframes.core.schema

     node = compiler._replace_unsupported_ops(node)
-    node, _ = rewrite.pull_up_order(node, order_root=False)
+    node = rewrite.bake_order(node)
     ir = compiler.compile_node(node)
     items = tuple(
         bigframes.core.schema.SchemaItem(name, ir.get_column_type(ibis_id))
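A minimal sketch of how a caller exercises the new request/result surface, assuming an already-built plan `node: bigframes.core.nodes.BigFrameNode` (the variable names are illustrative, not from the commit):

import bigframes.core.compile.compiler as compiler

# Ordered compile: sorted rows plus an optional peek limit.
request = compiler.CompileRequest(node, sort_rows=True, peek_count=100)
result = compiler.compile_sql(request)
print(result.sql)         # final SELECT statement
print(result.sql_schema)  # Sequence[bigquery.SchemaField]

# Raw compile: keep hidden ordering keys so row order can be reconstructed later.
raw_request = compiler.CompileRequest(
    node, sort_rows=False, materialize_all_order_keys=True
)
raw_result = compiler.compile_sql(raw_request)
assert raw_result.row_order is not None  # guaranteed when materialize_all_order_keys=True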

bigframes/core/compile/compiled.py

Lines changed: 15 additions & 10 deletions
@@ -69,23 +69,28 @@ def __init__(

     def to_sql(
         self,
-        *,
-        order_by: Sequence[OrderingExpression] = (),
-        limit: Optional[int] = None,
-        selections: Optional[Sequence[str]] = None,
+        order_by: Sequence[OrderingExpression],
+        limit: Optional[int],
+        selections: tuple[tuple[ex.DerefOp, str], ...],
     ) -> str:
         ibis_table = self._to_ibis_expr()
         # This set of output transforms maybe should be its own output node??
-        if (
-            order_by
-            or limit
-            or (selections and (tuple(selections) != tuple(self.column_ids)))
-        ):
+
+        selection_strings = tuple((ref.id.sql, name) for ref, name in selections)
+
+        names_preserved = tuple(name for _, name in selections) == tuple(
+            self.column_ids
+        )
+        is_noop_selection = (
+            all((i[0] == i[1] for i in selection_strings)) and names_preserved
+        )
+
+        if order_by or limit or not is_noop_selection:
             sql = ibis_bigquery.Backend().compile(ibis_table)
             sql = (
                 bigframes.core.compile.googlesql.Select()
                 .from_(sql)
-                .select(selections or self.column_ids)
+                .select(selection_strings)
                 .sql()
             )
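The no-op detection above boils down to a tuple comparison. A rough standalone illustration of that check, with plain (source name, output name) pairs standing in for the real `DerefOp` selections (an assumption about shapes, not the actual types):

def is_noop_selection(selections, column_ids):
    # selections: ((source_sql_id, output_name), ...); column_ids: current output names.
    names_preserved = tuple(name for _, name in selections) == tuple(column_ids)
    return all(src == name for src, name in selections) and names_preserved

# Re-selecting the same columns under the same names: no wrapper SELECT needed.
assert is_noop_selection((("col_0", "col_0"), ("col_1", "col_1")), ("col_0", "col_1"))
# Renaming any column forces the extra SELECT layer.
assert not is_noop_selection((("col_0", "a"), ("col_1", "col_1")), ("col_0", "col_1"))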

bigframes/core/compile/compiler.py

Lines changed: 57 additions & 36 deletions
@@ -13,8 +13,10 @@
 # limitations under the License.
 from __future__ import annotations

+import dataclasses
 import functools
 import typing
+from typing import cast, Optional

 import bigframes_vendored.ibis.backends.bigquery as ibis_bigquery
 import bigframes_vendored.ibis.expr.api as ibis_api
@@ -24,6 +26,7 @@
 import pyarrow as pa

 from bigframes import dtypes, operations
+from bigframes.core import expression
 import bigframes.core.compile.compiled as compiled
 import bigframes.core.compile.concat as concat_impl
 import bigframes.core.compile.explode
@@ -34,48 +37,58 @@
 if typing.TYPE_CHECKING:
     import bigframes.core
-    import bigframes.session


-def compile_sql(
-    node: nodes.BigFrameNode,
-    ordered: bool,
-    limit: typing.Optional[int] = None,
-) -> str:
-    # later steps might add ids, so snapshot before those steps.
-    output_ids = node.schema.names
-    if ordered:
-        # Need to do this before replacing unsupported ops, as that will rewrite slice ops
-        node, pulled_up_limit = rewrites.pullup_limit_from_slice(node)
-        if (pulled_up_limit is not None) and (
-            (limit is None) or limit > pulled_up_limit
-        ):
-            limit = pulled_up_limit
+@dataclasses.dataclass(frozen=True)
+class CompileRequest:
+    node: nodes.BigFrameNode
+    sort_rows: bool
+    materialize_all_order_keys: bool = False
+    peek_count: typing.Optional[int] = None
+
+
+@dataclasses.dataclass(frozen=True)
+class CompileResult:
+    sql: str
+    sql_schema: typing.Sequence[google.cloud.bigquery.SchemaField]
+    row_order: Optional[bf_ordering.RowOrdering]

-    node = _replace_unsupported_ops(node)
+
+def compile_sql(request: CompileRequest) -> CompileResult:
+    output_names = tuple((expression.DerefOp(id), id.sql) for id in request.node.ids)
+    result_node = nodes.ResultNode(
+        request.node,
+        output_cols=output_names,
+        limit=request.peek_count,
+    )
+    if request.sort_rows:
+        # Can only pullup slice if we are doing ORDER BY in outermost SELECT
+        # Need to do this before replacing unsupported ops, as that will rewrite slice ops
+        result_node = rewrites.pull_up_limits(result_node)
+    result_node = _replace_unsupported_ops(result_node)
     # prune before pulling up order to avoid unnnecessary row_number() ops
-    node = rewrites.column_pruning(node)
-    node, ordering = rewrites.pull_up_order(node, order_root=ordered)
-    # final pruning to cleanup up any leftovers unused values
-    node = rewrites.column_pruning(node)
-    return compile_node(node).to_sql(
-        order_by=ordering.all_ordering_columns if ordered else (),
-        limit=limit,
-        selections=output_ids,
+    result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
+    result_node = rewrites.defer_order(
+        result_node, output_hidden_row_keys=request.materialize_all_order_keys
     )
+    if request.sort_rows:
+        result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
+        sql = compile_result_node(result_node)
+        return CompileResult(
+            sql, result_node.schema.to_bigquery(), result_node.order_by
        )

-
-def compile_raw(
-    node: nodes.BigFrameNode,
-) -> typing.Tuple[
-    str, typing.Sequence[google.cloud.bigquery.SchemaField], bf_ordering.RowOrdering
-]:
-    node = _replace_unsupported_ops(node)
-    node = rewrites.column_pruning(node)
-    node, ordering = rewrites.pull_up_order(node, order_root=True)
-    node = rewrites.column_pruning(node)
-    sql = compile_node(node).to_sql()
-    return sql, node.schema.to_bigquery(), ordering
+    ordering: Optional[bf_ordering.RowOrdering] = result_node.order_by
+    result_node = dataclasses.replace(result_node, order_by=None)
+    result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
+    sql = compile_result_node(result_node)
+    # Return the ordering iff no extra columns are needed to define the row order
+    if ordering is not None:
+        output_order = (
+            ordering if ordering.referenced_columns.issubset(result_node.ids) else None
+        )
+    assert (not request.materialize_all_order_keys) or (output_order is not None)
+    return CompileResult(sql, result_node.schema.to_bigquery(), output_order)


 def _replace_unsupported_ops(node: nodes.BigFrameNode):
@@ -86,6 +99,14 @@ def _replace_unsupported_ops(node: nodes.BigFrameNode):
     return node


+def compile_result_node(root: nodes.ResultNode) -> str:
+    return compile_node(root.child).to_sql(
+        order_by=root.order_by.all_ordering_columns if root.order_by else (),
+        limit=root.limit,
+        selections=root.output_cols,
+    )
+
+
 # TODO: Remove cache when schema no longer requires compilation to derive schema (and therefor only compiles for execution)
 @functools.lru_cache(maxsize=5000)
 def compile_node(node: nodes.BigFrameNode) -> compiled.UnorderedIR:
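For reference, the old entry points map onto the new request type roughly as follows. This is a simplified, self-contained stand-in (field names match the diff; the real field types come from BigFrames node, BigQuery schema, and ordering modules):

import dataclasses
from typing import Optional, Sequence

@dataclasses.dataclass(frozen=True)
class CompileRequest:
    node: object                               # nodes.BigFrameNode in the real class
    sort_rows: bool
    materialize_all_order_keys: bool = False
    peek_count: Optional[int] = None

@dataclasses.dataclass(frozen=True)
class CompileResult:
    sql: str
    sql_schema: Sequence[object]               # Sequence[bigquery.SchemaField]
    row_order: Optional[object]                # Optional[bf_ordering.RowOrdering]

# old compile_sql(node, ordered=True, limit=10):
ordered = CompileRequest(node=None, sort_rows=True, peek_count=10)
# old compile_raw(node), which needs hidden ordering keys in the output:
raw = CompileRequest(node=None, sort_rows=False, materialize_all_order_keys=True)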

bigframes/core/compile/googlesql/expression.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@
 * `expression`: Models basic SQL expressions.

 Extended classes (not part of standard GoogleSQL syntax, but added for convenience):
-
+i
 * `ColumnExpression`: Represents column references.
 * `TableExpression`: Represents table references.
 * `AliasExpression`: Represents aliased expressions.

bigframes/core/compile/googlesql/query.py

Lines changed: 14 additions & 5 deletions
@@ -63,22 +63,31 @@ class Select(abc.SQLSyntax):

     def select(
         self,
-        columns: typing.Union[typing.Iterable[str], str, None] = None,
+        columns: typing.Union[
+            typing.Iterable[str], typing.Iterable[tuple[str, str]], str, None
+        ] = None,
         distinct: bool = False,
     ) -> Select:
         if isinstance(columns, str):
             columns = [columns]
         self.select_list: typing.List[typing.Union[SelectExpression, SelectAll]] = (
-            [
-                SelectExpression(expression=expr.ColumnExpression(name=column))
-                for column in columns
-            ]
+            [self._select_field(column) for column in columns]
             if columns
             else [SelectAll(expression=expr.StarExpression())]
         )
         self.distinct = distinct
         return self

+    def _select_field(self, field) -> SelectExpression:
+        if isinstance(field, str):
+            return SelectExpression(expression=expr.ColumnExpression(name=field))
+
+        else:
+            alias = field[1] if (field[0] != field[1]) else None
+            return SelectExpression(
+                expression=expr.ColumnExpression(name=field[0]), alias=alias
+            )
+
     def from_(
         self,
         sources: typing.Union[TABLE_SOURCE_TYPE, typing.Iterable[TABLE_SOURCE_TYPE]],
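A small sketch of the new (source name, output name) selection form. The builder calls mirror how compiled.py uses this class above; the rendered SQL shown in the comment is approximate, not an exact output of the library:

from bigframes.core.compile import googlesql

inner = "SELECT * FROM my_table"  # any SQL string used as the FROM source
sql = (
    googlesql.Select()
    .from_(inner)
    .select([("col_0", "a"), ("col_1", "col_1")])  # (source name, output name) pairs
    .sql()
)
# Roughly: SELECT `col_0` AS `a`, `col_1` FROM (SELECT * FROM my_table)
# `col_1` gets no alias because source and output names match (see _select_field).
print(sql)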

bigframes/core/nodes.py

Lines changed: 42 additions & 3 deletions
@@ -36,7 +36,7 @@
 from bigframes.core import identifiers, local_data
 from bigframes.core.bigframe_node import BigFrameNode, COLUMN_SET, Field
 import bigframes.core.expression as ex
-from bigframes.core.ordering import OrderingExpression
+from bigframes.core.ordering import OrderingExpression, RowOrdering
 import bigframes.core.slices as slices
 import bigframes.core.window_spec as window
 import bigframes.dtypes
@@ -1602,11 +1602,50 @@ def remap_refs(


 # Introduced during planing/compilation
+# TODO: Enforce more strictly that this should never be a child node
 @dataclasses.dataclass(frozen=True, eq=False)
 class ResultNode(UnaryNode):
-    output_names: tuple[str, ...]
-    order_by: Tuple[OrderingExpression, ...] = ()
+    output_cols: tuple[tuple[ex.DerefOp, str], ...]
+    order_by: Optional[RowOrdering] = None
     limit: Optional[int] = None
+    # TODO: CTE definitions
+
+    @property
+    def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
+        return ()
+
+    def remap_vars(
+        self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
+    ) -> ResultNode:
+        return self
+
+    def remap_refs(
+        self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
+    ) -> ResultNode:
+        output_names = tuple(
+            (ref.remap_column_refs(mappings), name) for ref, name in self.output_cols
+        )
+        order_by = self.order_by.remap_column_refs(mappings) if self.order_by else None
+        return dataclasses.replace(self, output_names=output_names, order_by=order_by)  # type: ignore
+
+    @property
+    def consumed_ids(self) -> COLUMN_SET:
+        out_refs = frozenset(ref.id for ref, _ in self.output_cols)
+        order_refs = self.order_by.referenced_columns if self.order_by else frozenset()
+        return out_refs | order_refs
+
+    @property
+    def row_count(self) -> Optional[int]:
+        child_count = self.child.row_count
+        if child_count is None:
+            return None
+        if self.limit is None:
+            return child_count
+        return min(self.limit, child_count)
+
+    @property
+    def variables_introduced(self) -> int:
+        return 0


 # Tree operators
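The `row_count` property above is just a limit cap over the child's (possibly unknown) count. A self-contained illustration of that logic with plain values:

def result_row_count(child_count, limit):
    # Mirrors ResultNode.row_count: unknown child counts stay unknown,
    # otherwise the limit caps the child's row count.
    if child_count is None:
        return None
    if limit is None:
        return child_count
    return min(limit, child_count)

assert result_row_count(1000, 10) == 10
assert result_row_count(5, 10) == 5
assert result_row_count(None, 10) is None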

bigframes/core/rewrite/__init__.py

Lines changed: 5 additions & 4 deletions
@@ -15,10 +15,10 @@
 from bigframes.core.rewrite.identifiers import remap_variables
 from bigframes.core.rewrite.implicit_align import try_row_join
 from bigframes.core.rewrite.legacy_align import legacy_join_as_projection
-from bigframes.core.rewrite.order import pull_up_order
+from bigframes.core.rewrite.order import bake_order, defer_order
 from bigframes.core.rewrite.pruning import column_pruning
 from bigframes.core.rewrite.scan_reduction import try_reduce_to_table_scan
-from bigframes.core.rewrite.slices import pullup_limit_from_slice, rewrite_slice
+from bigframes.core.rewrite.slices import pull_up_limits, rewrite_slice
 from bigframes.core.rewrite.timedeltas import rewrite_timedelta_expressions
 from bigframes.core.rewrite.windows import rewrite_range_rolling

@@ -27,10 +27,11 @@
     "try_row_join",
     "rewrite_slice",
     "rewrite_timedelta_expressions",
-    "pullup_limit_from_slice",
+    "pull_up_limits",
     "remap_variables",
-    "pull_up_order",
+    "defer_order",
     "column_pruning",
     "rewrite_range_rolling",
     "try_reduce_to_table_scan",
+    "bake_order",
 ]

bigframes/core/rewrite/order.py

Lines changed: 27 additions & 4 deletions
@@ -15,17 +15,40 @@
 import functools
 from typing import Mapping, Tuple

-from bigframes.core import identifiers
-import bigframes.core.expression
+from bigframes.core import expression, identifiers
 import bigframes.core.nodes
 import bigframes.core.ordering
 import bigframes.core.window_spec
-import bigframes.operations
 from bigframes.operations import aggregations as agg_ops


+def defer_order(
+    root: bigframes.core.nodes.ResultNode, output_hidden_row_keys: bool
+) -> bigframes.core.nodes.ResultNode:
+    new_child, order = _pull_up_order(root.child, order_root=True)
+    order_by = (
+        order.with_ordering_columns(root.order_by.all_ordering_columns)
+        if root.order_by
+        else order
+    )
+    if output_hidden_row_keys:
+        output_names = tuple((expression.DerefOp(id), id.sql) for id in new_child.ids)
+    else:
+        output_names = root.output_cols
+    return dataclasses.replace(
+        root, output_cols=output_names, child=new_child, order_by=order_by
+    )
+
+
+def bake_order(
+    node: bigframes.core.nodes.BigFrameNode,
+) -> bigframes.core.nodes.BigFrameNode:
+    node, _ = _pull_up_order(node, order_root=False)
+    return node
+
+
 # Makes ordering explicit in window definitions
-def pull_up_order(
+def _pull_up_order(
     root: bigframes.core.nodes.BigFrameNode,
     *,
     order_root: bool = True,