Skip to content

Commit 59c52a5

Browse files
feat: Or, And, Xor can execute locally (#1994)
1 parent e83836e commit 59c52a5

File tree

3 files changed

+74
-1
lines changed

3 files changed

+74
-1
lines changed

bigframes/core/compile/polars/compiler.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,10 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
198198
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
199199
return l_input | r_input
200200

201+
@compile_op.register(bool_ops.XorOp)
202+
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
203+
return l_input ^ r_input
204+
201205
@compile_op.register(num_ops.AddOp)
202206
def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
203207
return l_input + r_input

bigframes/session/polars_executor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from bigframes.core import array_value, bigframe_node, expression, local_data, nodes
2222
import bigframes.operations
2323
from bigframes.operations import aggregations as agg_ops
24-
from bigframes.operations import comparison_ops, generic_ops, numeric_ops
24+
from bigframes.operations import bool_ops, comparison_ops, generic_ops, numeric_ops
2525
from bigframes.session import executor, semi_executor
2626

2727
if TYPE_CHECKING:
@@ -44,6 +44,9 @@
4444
)
4545

4646
_COMPATIBLE_SCALAR_OPS = (
47+
bool_ops.AndOp,
48+
bool_ops.OrOp,
49+
bool_ops.XorOp,
4750
comparison_ops.EqOp,
4851
comparison_ops.EqNullsMatchOp,
4952
comparison_ops.NeOp,
@@ -63,6 +66,8 @@
6366
generic_ops.FillNaOp,
6467
generic_ops.CaseWhenOp,
6568
generic_ops.InvertOp,
69+
generic_ops.IsNullOp,
70+
generic_ops.NotNullOp,
6671
)
6772
_COMPATIBLE_AGG_OPS = (
6873
agg_ops.SizeOp,
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import itertools
16+
17+
import pytest
18+
19+
from bigframes.core import array_value
20+
import bigframes.operations as ops
21+
from bigframes.session import polars_executor
22+
from bigframes.testing.engine_utils import assert_equivalence_execution
23+
24+
pytest.importorskip("polars")
25+
26+
# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree.
27+
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
28+
29+
30+
def apply_op_pairwise(
31+
array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[]
32+
) -> array_value.ArrayValue:
33+
exprs = []
34+
for l_arg, r_arg in itertools.permutations(array.column_ids, 2):
35+
if (l_arg in excluded_cols) or (r_arg in excluded_cols):
36+
continue
37+
try:
38+
_ = op.output_type(
39+
array.get_column_type(l_arg), array.get_column_type(r_arg)
40+
)
41+
exprs.append(op.as_expr(l_arg, r_arg))
42+
except TypeError:
43+
continue
44+
assert len(exprs) > 0
45+
new_arr, _ = array.compute_values(exprs)
46+
return new_arr
47+
48+
49+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
50+
@pytest.mark.parametrize(
51+
"op",
52+
[
53+
ops.and_op,
54+
ops.or_op,
55+
ops.xor_op,
56+
],
57+
)
58+
def test_engines_project_boolean_op(
59+
scalars_array_value: array_value.ArrayValue, engine, op
60+
):
61+
# exclude string cols as does not contain dates
62+
# bool col actually doesn't work properly for bq engine
63+
arr = apply_op_pairwise(scalars_array_value, op, excluded_cols=["string_col"])
64+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)