Skip to content

Commit 8715105

Browse files
feat: Add simple stats support to hybrid local pushdown (#1873)
1 parent dba2a6e commit 8715105

File tree

3 files changed

+52
-5
lines changed

bigframes/session/polars_executor.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,15 @@
4747
bigframes.operations.ge_op,
4848
bigframes.operations.le_op,
4949
)
50-
_COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp)
50+
# Aggregation ops the local polars executor can evaluate.
# NOTE(review): per the commit title this enables "simple stats" pushdown —
# presumably engine dispatch checks membership here; confirm against callers.
_COMPATIBLE_AGG_OPS = (
    agg_ops.SizeOp,
    agg_ops.SizeUnaryOp,
    agg_ops.MinOp,
    agg_ops.MaxOp,
    agg_ops.SumOp,
    agg_ops.MeanOp,
    agg_ops.CountOp,
)
5159

5260

5361
def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]:

bigframes/testing/engine_utils.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pandas.testing
16+
1517
from bigframes.core import nodes
1618
from bigframes.session import semi_executor
1719

@@ -25,7 +27,8 @@ def assert_equivalence_execution(
2527
e2_result = engine2.execute(node, ordered=True)
2628
assert e1_result is not None
2729
assert e2_result is not None
28-
# Schemas might have extra nullity markers, normalize to node expected schema, which should be looser
29-
e1_table = e1_result.to_arrow_table().cast(node.schema.to_pyarrow())
30-
e2_table = e2_result.to_arrow_table().cast(node.schema.to_pyarrow())
31-
assert e1_table.equals(e2_table), f"{e1_table} is not equal to {e2_table}"
30+
# Convert to pandas, as pandas has better comparison utils than arrow
31+
assert e1_result.schema == e2_result.schema
32+
e1_table = e1_result.to_pandas()
33+
e2_table = e2_result.to_pandas()
34+
pandas.testing.assert_frame_equal(e1_table, e2_table, rtol=1e-10)

tests/system/small/engines/test_aggregation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,28 @@
2525
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2626

2727

28+
def apply_agg_to_all_valid(
    array: array_value.ArrayValue, op: agg_ops.UnaryAggregateOp, excluded_cols=()
) -> array_value.ArrayValue:
    """
    Apply the aggregation to every column in the array that has a compatible datatype.

    Args:
        array: Source array value whose columns are candidates for aggregation.
        op: Unary aggregation to apply to each compatible column.
        excluded_cols: Column ids to skip even if their type is compatible.
            Defaults to an empty tuple (was a mutable ``[]`` default, which is
            shared across calls — fixed to an immutable default).

    Returns:
        A new ArrayValue aggregating each compatible column, with result
        columns named ``"{column_id}-{op.name}"``.

    Raises:
        AssertionError: If no column is compatible with ``op``.
    """
    exprs_by_name = []
    for arg in array.column_ids:
        if arg in excluded_cols:
            continue
        try:
            # Probe type compatibility; incompatible dtypes raise TypeError.
            _ = op.output_type(array.get_column_type(arg))
        except TypeError:
            continue
        # Construct outside the try so an unexpected TypeError here is not
        # silently swallowed as "incompatible column".
        expr = expression.UnaryAggregation(op, expression.deref(arg))
        name = f"{arg}-{op.name}"
        exprs_by_name.append((expr, name))
    assert len(exprs_by_name) > 0
    new_arr = array.aggregate(exprs_by_name)
    return new_arr
48+
49+
2850
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
2951
def test_engines_aggregate_size(
3052
scalars_array_value: array_value.ArrayValue,
@@ -48,6 +70,20 @@ def test_engines_aggregate_size(
4870
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
4971

5072

73+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
@pytest.mark.parametrize(
    "op",
    [agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op],
)
def test_engines_unary_aggregates(
    scalars_array_value: array_value.ArrayValue,
    engine,
    op,
):
    """Applying the parametrized unary agg over all compatible columns must
    produce identical results on the engine under test and the reference engine."""
    aggregated = apply_agg_to_all_valid(scalars_array_value, op)
    assert_equivalence_execution(aggregated.node, REFERENCE_ENGINE, engine)
86+
5187
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
5288
@pytest.mark.parametrize(
5389
"grouping_cols",

0 commit comments

Comments
 (0)