Skip to content

Commit 26df6e6

Browse files
feat: Add isin local execution impl (#1993)
1 parent b454256 commit 26df6e6

File tree

7 files changed

+88
-9
lines changed

7 files changed

+88
-9
lines changed

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ def isin_op_impl(x: ibis_types.Value, op: ops.IsInOp):
10621062
if op.match_nulls and contains_nulls:
10631063
return x.isnull() | x.isin(matchable_ibis_values)
10641064
else:
1065-
return x.isin(matchable_ibis_values)
1065+
return x.isin(matchable_ibis_values).fillna(False)
10661066

10671067

10681068
@scalar_op_compiler.register_unary_op(ops.ToDatetimeOp, pass_op=True)

bigframes/core/compile/polars/compiler.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,9 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
263263
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
264264
# TODO: Filter out types that can't be coerced to right type
265265
assert isinstance(op, gen_ops.IsInOp)
266-
if op.match_nulls or not any(map(pd.isna, op.values)):
267-
# newer polars version have nulls_equal arg
268-
return input.is_in(op.values)
269-
else:
270-
return input.is_in(op.values) or input.is_null()
266+
assert not op.match_nulls # should be stripped by a lowering step rn
267+
values = pl.Series(op.values, strict=False)
268+
return input.is_in(values)
271269

272270
@compile_op.register(gen_ops.FillNaOp)
273271
@compile_op.register(gen_ops.CoalesceOp)

bigframes/core/compile/polars/lowering.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
from typing import cast
1617

1718
import numpy as np
19+
import pandas as pd
1820

1921
from bigframes import dtypes
2022
from bigframes.core import bigframe_node, expression
@@ -316,6 +318,35 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
316318
return expr
317319

318320

321+
class LowerIsinOp(op_lowering.OpLoweringRule):
322+
@property
323+
def op(self) -> type[ops.ScalarOp]:
324+
return generic_ops.IsInOp
325+
326+
def lower(self, expr: expression.OpExpression) -> expression.Expression:
327+
assert isinstance(expr.op, generic_ops.IsInOp)
328+
arg = expr.children[0]
329+
new_values = []
330+
match_nulls = False
331+
for val in expr.op.values:
332+
# coercible, non-coercible
333+
# float NaN/inf should be treated as distinct from 'true' null values
334+
if cast(bool, pd.isna(val)) and not isinstance(val, float):
335+
if expr.op.match_nulls:
336+
match_nulls = True
337+
elif dtypes.is_compatible(val, arg.output_type):
338+
new_values.append(val)
339+
else:
340+
pass
341+
342+
new_isin = ops.IsInOp(tuple(new_values), match_nulls=False).as_expr(arg)
343+
if match_nulls:
344+
return ops.coalesce_op.as_expr(new_isin, expression.const(True))
345+
else:
346+
# polars propagates nulls, so need to coalesce to false
347+
return ops.coalesce_op.as_expr(new_isin, expression.const(False))
348+
349+
319350
def _coerce_comparables(
320351
expr1: expression.Expression,
321352
expr2: expression.Expression,
@@ -414,6 +445,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
414445
LowerModRule(),
415446
LowerAsTypeRule(),
416447
LowerInvertOp(),
448+
LowerIsinOp(),
417449
)
418450

419451

bigframes/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2755,11 +2755,11 @@ def isin(self, values) -> DataFrame:
27552755
False, label=label, dtype=pandas.BooleanDtype()
27562756
)
27572757
result_ids.append(result_id)
2758-
return DataFrame(block.select_columns(result_ids)).fillna(value=False)
2758+
return DataFrame(block.select_columns(result_ids))
27592759
elif utils.is_list_like(values):
27602760
return self._apply_unary_op(
27612761
ops.IsInOp(values=tuple(values), match_nulls=True)
2762-
).fillna(value=False)
2762+
)
27632763
else:
27642764
raise TypeError(
27652765
"only list-like objects are allowed to be passed to "

bigframes/session/polars_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
generic_ops.FillNaOp,
6767
generic_ops.CaseWhenOp,
6868
generic_ops.InvertOp,
69+
generic_ops.IsInOp,
6970
generic_ops.IsNullOp,
7071
generic_ops.NotNullOp,
7172
)

tests/system/small/engines/test_generic_ops.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,36 @@ def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
390390
)
391391

392392
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
393+
394+
395+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
396+
def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine):
397+
arr, col_ids = scalars_array_value.compute_values(
398+
[
399+
ops.IsInOp((1, 2, 3)).as_expr(expression.deref("int64_col")),
400+
ops.IsInOp((None, 123456)).as_expr(expression.deref("int64_col")),
401+
ops.IsInOp((None, 123456), match_nulls=False).as_expr(
402+
expression.deref("int64_col")
403+
),
404+
ops.IsInOp((1.0, 2.0, 3.0)).as_expr(expression.deref("int64_col")),
405+
ops.IsInOp(("1.0", "2.0")).as_expr(expression.deref("int64_col")),
406+
ops.IsInOp(("1.0", 2.5, 3)).as_expr(expression.deref("int64_col")),
407+
ops.IsInOp(()).as_expr(expression.deref("int64_col")),
408+
ops.IsInOp((1, 2, 3, None)).as_expr(expression.deref("float64_col")),
409+
]
410+
)
411+
new_names = (
412+
"int in ints",
413+
"int in ints w null",
414+
"int in ints w null wo match nulls",
415+
"int in floats",
416+
"int in strings",
417+
"int in mixed",
418+
"int in empty",
419+
"float in ints",
420+
)
421+
arr = arr.rename_columns(
422+
{old_name: new_names[i] for i, old_name in enumerate(col_ids)}
423+
)
424+
425+
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

tests/system/small/test_dataframe.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1591,7 +1591,7 @@ def test_itertuples(scalars_df_index, index, name):
15911591
assert bf_tuple == pd_tuple
15921592

15931593

1594-
def test_df_isin_list(scalars_dfs):
1594+
def test_df_isin_list_w_null(scalars_dfs):
15951595
scalars_df, scalars_pandas_df = scalars_dfs
15961596
values = ["Hello, World!", 55555, 2.51, pd.NA, True]
15971597
bf_result = (
@@ -1606,6 +1606,21 @@ def test_df_isin_list(scalars_dfs):
16061606
pandas.testing.assert_frame_equal(bf_result, pd_result.astype("boolean"))
16071607

16081608

1609+
def test_df_isin_list_wo_null(scalars_dfs):
1610+
scalars_df, scalars_pandas_df = scalars_dfs
1611+
values = ["Hello, World!", 55555, 2.51, True]
1612+
bf_result = (
1613+
scalars_df[["int64_col", "float64_col", "string_col", "bool_col"]]
1614+
.isin(values)
1615+
.to_pandas()
1616+
)
1617+
pd_result = scalars_pandas_df[
1618+
["int64_col", "float64_col", "string_col", "bool_col"]
1619+
].isin(values)
1620+
1621+
pandas.testing.assert_frame_equal(bf_result, pd_result.astype("boolean"))
1622+
1623+
16091624
def test_df_isin_dict(scalars_dfs):
16101625
scalars_df, scalars_pandas_df = scalars_dfs
16111626
values = {

0 commit comments

Comments
 (0)