Skip to content

Commit ac55aae

Browse files
test: Cross validate sort execution between engines (#1823)
1 parent a4205f8 commit ac55aae

File tree

5 files changed

+197
-9
lines changed

5 files changed

+197
-9
lines changed

bigframes/core/bigframe_node.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,19 @@
2020
import functools
2121
import itertools
2222
import typing
23-
from typing import Callable, Dict, Generator, Iterable, Mapping, Sequence, Set, Tuple
24-
25-
from bigframes.core import field, identifiers
23+
from typing import (
24+
Callable,
25+
Dict,
26+
Generator,
27+
Iterable,
28+
Mapping,
29+
Sequence,
30+
Set,
31+
Tuple,
32+
Union,
33+
)
34+
35+
from bigframes.core import expression, field, identifiers
2636
import bigframes.core.schema as schemata
2737
import bigframes.dtypes
2838

@@ -278,6 +288,13 @@ def _dtype_lookup(self) -> dict[identifiers.ColumnId, bigframes.dtypes.Dtype]:
278288
def field_by_id(self) -> Mapping[identifiers.ColumnId, field.Field]:
279289
return {field.id: field for field in self.fields}
280290

291+
@property
def _node_expressions(
    self,
) -> Sequence[Union[expression.Expression, expression.Aggregation]]:
    """Expressions evaluated directly by this node.

    The base implementation reports none. Intended for checking engine
    compatibility with the ops a plan uses.
    """
    return tuple()
297+
281298
# Plan algorithms
282299
def unique_nodes(
283300
self: BigFrameNode,

bigframes/core/nodes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,10 @@ def joins_nulls(self) -> bool:
274274
right_nullable = self.right_child.field_by_id[self.right_col.id].nullable
275275
return left_nullable or right_nullable
276276

277+
@property
def _node_expressions(self):
    """Return this node's left and right column expressions as a pair."""
    left = self.left_col
    right = self.right_col
    return (left, right)
280+
277281
def replace_additive_base(self, node: BigFrameNode):
278282
return dataclasses.replace(self, left_child=node)
279283

@@ -387,6 +391,10 @@ def referenced_ids(self) -> COLUMN_SET:
387391
def consumed_ids(self) -> COLUMN_SET:
388392
return frozenset(*self.ids, *self.referenced_ids)
389393

394+
@property
def _node_expressions(self):
    """Return every expression in ``self.conditions``, flattened into one tuple."""
    return tuple(
        operand for condition in self.conditions for operand in condition
    )
397+
390398
def transform_children(self, t: Callable[[BigFrameNode], BigFrameNode]) -> JoinNode:
391399
transformed = dataclasses.replace(
392400
self, left_child=t(self.left_child), right_child=t(self.right_child)
@@ -996,6 +1004,10 @@ def consumed_ids(self) -> COLUMN_SET:
9961004
def referenced_ids(self) -> COLUMN_SET:
9971005
return frozenset(self.predicate.column_references)
9981006

1007+
@property
def _node_expressions(self):
    """Return the node's predicate wrapped in a one-element tuple."""
    predicate = self.predicate
    return (predicate,)
1010+
9991011
def remap_vars(
10001012
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
10011013
) -> FilterNode:
@@ -1050,6 +1062,10 @@ def referenced_ids(self) -> COLUMN_SET:
10501062
itertools.chain.from_iterable(map(lambda x: x.referenced_columns, self.by))
10511063
)
10521064

1065+
@property
def _node_expressions(self):
    """Return the scalar expression of each entry in ``self.by``, in order."""
    return tuple(order_part.scalar_expression for order_part in self.by)
1068+
10531069
def remap_vars(
10541070
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
10551071
) -> OrderByNode:
@@ -1178,6 +1194,10 @@ def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
11781194
def consumed_ids(self) -> COLUMN_SET:
11791195
return frozenset(ref.id for ref, id in self.input_output_pairs)
11801196

1197+
@property
def _node_expressions(self):
    """Return the input reference of each (input, output) pair, in order."""
    return tuple(source_ref for source_ref, _ in self.input_output_pairs)
1200+
11811201
def get_id_mapping(self) -> dict[identifiers.ColumnId, identifiers.ColumnId]:
11821202
return {ref.id: id for ref, id in self.input_output_pairs}
11831203

@@ -1265,6 +1285,10 @@ def referenced_ids(self) -> COLUMN_SET:
12651285
)
12661286
)
12671287

1288+
@property
def _node_expressions(self):
    """Return the expression half of each entry in ``self.assignments``."""
    return tuple(assigned for assigned, _ in self.assignments)
1291+
12681292
@property
12691293
def additive_base(self) -> BigFrameNode:
12701294
return self.child
@@ -1361,6 +1385,13 @@ def has_ordered_ops(self) -> bool:
13611385
aggregate.op.order_independent for aggregate, _ in self.aggregations
13621386
)
13631387

1388+
@property
def _node_expressions(self):
    """Return grouping references, then aggregations, then ordering expressions."""
    collected = list(self.by_column_ids)
    collected.extend(aggregation for aggregation, _ in self.aggregations)
    collected.extend(part.scalar_expression for part in self.order_by)
    return tuple(collected)
1394+
13641395
def remap_vars(
13651396
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
13661397
) -> AggregateNode:
@@ -1463,6 +1494,10 @@ def inherits_order(self) -> bool:
14631494
def additive_base(self) -> BigFrameNode:
14641495
return self.child
14651496

1497+
@property
def _node_expressions(self):
    """Return the node's expression followed by the window spec's expressions."""
    window_expr = self.expression
    return (window_expr, *self.window_spec.expressions)
1500+
14661501
def replace_additive_base(self, node: BigFrameNode) -> WindowOpNode:
14671502
return dataclasses.replace(self, child=node)
14681503

@@ -1584,6 +1619,10 @@ def node_defined_ids(self) -> Tuple[identifiers.ColumnId, ...]:
15841619
def referenced_ids(self) -> COLUMN_SET:
15851620
return frozenset(ref.id for ref in self.column_ids)
15861621

1622+
@property
def _node_expressions(self):
    """Return ``self.column_ids`` as-is (the same object, not a copy)."""
    column_refs = self.column_ids
    return column_refs
1625+
15871626
def remap_vars(
15881627
self, mappings: Mapping[identifiers.ColumnId, identifiers.ColumnId]
15891628
) -> ExplodeNode:
@@ -1657,6 +1696,10 @@ def row_count(self) -> Optional[int]:
16571696
def variables_introduced(self) -> int:
16581697
return 0
16591698

1699+
@property
def _node_expressions(self):
    """Return the first element of each pair in ``self.output_cols``, in order."""
    return tuple(source_ref for source_ref, _label in self.output_cols)
1702+
16601703

16611704
# Tree operators
16621705
def top_down(

bigframes/core/window_spec.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from dataclasses import dataclass, replace
1717
import datetime
1818
import itertools
19-
from typing import Literal, Mapping, Optional, Set, Tuple, Union
19+
from typing import Literal, Mapping, Optional, Sequence, Set, Tuple, Union
2020

2121
import numpy as np
2222
import pandas as pd
@@ -260,6 +260,11 @@ def is_unbounded(self):
260260
self.bounds.start is None and self.bounds.end is None
261261
)
262262

263+
@property
def expressions(self) -> Sequence[ex.Expression]:
    """Return the grouping keys followed by each ordering item's scalar expression."""
    combined = list(self.grouping_keys)
    combined.extend(entry.scalar_expression for entry in self.ordering)
    return tuple(combined)
267+
263268
@property
264269
def all_referenced_columns(self) -> Set[ids.ColumnId]:
265270
"""

bigframes/session/polars_executor.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,26 +13,46 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16+
import itertools
1617
from typing import Optional, TYPE_CHECKING
1718

1819
import pyarrow as pa
1920

20-
from bigframes.core import array_value, bigframe_node, local_data, nodes
21+
from bigframes.core import array_value, bigframe_node, expression, local_data, nodes
22+
import bigframes.operations
2123
from bigframes.session import executor, semi_executor
2224

2325
if TYPE_CHECKING:
2426
import polars as pl
2527

26-
28+
# Polars executor can execute more node types, but these are the validated ones
2729
_COMPATIBLE_NODES = (
2830
nodes.ReadLocalNode,
2931
nodes.OrderByNode,
3032
nodes.ReversedNode,
3133
nodes.SelectionNode,
32-
nodes.FilterNode, # partial support
33-
nodes.ProjectionNode, # partial support
3434
)
3535

36+
_COMPATIBLE_SCALAR_OPS = ()
37+
38+
39+
def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]:
    """Collect every op applied anywhere inside ``expr``.

    Args:
        expr: Root of the expression tree to inspect.

    Returns:
        The set of ops used by ``expr`` and all of its descendants. Empty
        when ``expr`` is not an ``OpExpression`` (no op is applied).
    """
    if not isinstance(expr, expression.OpExpression):
        return set()
    # The node's own op must be included; recursing only into children would
    # always bottom out in empty sets and report that no ops are used at all.
    # NOTE(review): assumes OpExpression exposes its operator as `.op` — confirm.
    ops = {expr.op}
    for child in expr.children:
        ops.update(_get_expr_ops(child))
    return ops
43+
44+
45+
def _is_node_polars_executable(node: nodes.BigFrameNode):
    """Return True when ``node`` is a validated node type using only validated ops.

    A node is rejected if its type is not in ``_COMPATIBLE_NODES``, if any of
    its expressions is an aggregation, or if any scalar expression uses an op
    outside ``_COMPATIBLE_SCALAR_OPS``.
    """
    if isinstance(node, _COMPATIBLE_NODES):
        for candidate in node._node_expressions:
            if isinstance(candidate, expression.Aggregation):
                return False
            if not isinstance(candidate, expression.Expression):
                continue
            used_ops = _get_expr_ops(candidate)
            if not used_ops.issubset(_COMPATIBLE_SCALAR_OPS):
                return False
        return True
    return False
55+
3656

3757
class PolarsExecutor(semi_executor.SemiExecutor):
3858
def __init__(self):
@@ -67,7 +87,7 @@ def execute(
6787
)
6888

6989
def _can_execute(self, plan: bigframe_node.BigFrameNode):
70-
return all(isinstance(node, _COMPATIBLE_NODES) for node in plan.unique_nodes())
90+
return all(_is_node_polars_executable(node) for node in plan.unique_nodes())
7191

7292
def _adapt_array(self, array: pa.Array) -> pa.Array:
7393
target_type = local_data.logical_type_replacements(array.type)
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
from bigframes.core import array_value, nodes, ordering
18+
import bigframes.operations as bf_ops
19+
from bigframes.session import polars_executor
20+
from bigframes.testing.engine_utils import assert_equivalence_execution
21+
22+
pytest.importorskip("polars")
23+
24+
# Polars is used as the reference engine because it's fast and local. Where the engines disagree, though, generally prefer the gbq engine.
25+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
26+
27+
28+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_reverse(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """A single reversal must execute equivalently on the reference engine and ``engine``."""
    reversed_plan = apply_reverse(scalars_array_value.node)
    assert_equivalence_execution(reversed_plan, REFERENCE_ENGINE, engine)
35+
36+
37+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_double_reverse(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Two stacked reversals must execute equivalently on both engines."""
    node = apply_reverse(scalars_array_value.node)
    # Reverse a second time; with only one reversal this test would merely
    # duplicate test_engines_reverse.
    node = apply_reverse(node)
    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
44+
45+
46+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    "sort_col",
    [
        # Each column appears exactly once; a repeated entry only re-runs the
        # same case under a different generated test id.
        "bool_col",
        "int64_col",
        "bytes_col",
        "date_col",
        "datetime_col",
        "int64_too",
        "numeric_col",
        "float64_col",
        "string_col",
        "time_col",
        "timestamp_col",
    ],
)
def test_engines_sort_over_column(
    scalars_array_value: array_value.ArrayValue, engine, sort_col
):
    """Sorting a reversed plan by one column must match across engines.

    The sort is descending with nulls first.
    """
    node = apply_reverse(scalars_array_value.node)
    order_expressions = (ordering.descending_over(sort_col, nulls_last=False),)
    node = nodes.OrderByNode(node, order_expressions)
    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
71+
72+
73+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_sort_multi_column_refs(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Sorting by two column references must execute equivalently on both engines."""
    by_bool_then_int = (
        ordering.ascending_over("bool_col", nulls_last=False),
        ordering.descending_over("int64_col"),
    )
    sorted_plan = nodes.OrderByNode(scalars_array_value.node, by_bool_then_int)
    assert_equivalence_execution(sorted_plan, REFERENCE_ENGINE, engine)
85+
86+
87+
@pytest.mark.parametrize("engine", ["polars"], indirect=True)
def test_polars_engines_skips_unrecognized_order_expr(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Polars must decline (return ``None``) when ordering by an op it has not validated."""
    node = scalars_array_value.node
    order_expressions = (
        ordering.OrderingExpression(
            # Use the float column name the other tests in this module sort by,
            # so the plan is declined because of the op, not a missing column.
            scalar_expression=bf_ops.sin_op.as_expr("float64_col")
        ),
    )
    node = nodes.OrderByNode(node, order_expressions)
    assert engine.execute(node, ordered=True) is None
100+
101+
102+
def apply_reverse(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """Wrap ``node`` in a ``ReversedNode`` and return the new plan root."""
    reversed_node = nodes.ReversedNode(node)
    return reversed_node

0 commit comments

Comments
 (0)