Skip to content

Commit eae36b8

Browse files
feat: Add isin local execution impl
1 parent 0bd5e1b commit eae36b8

File tree

5 files changed

+69
-6
lines changed

5 files changed

+69
-6
lines changed

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1062,7 +1062,7 @@ def isin_op_impl(x: ibis_types.Value, op: ops.IsInOp):
10621062
if op.match_nulls and contains_nulls:
10631063
return x.isnull() | x.isin(matchable_ibis_values)
10641064
else:
1065-
return x.isin(matchable_ibis_values)
1065+
return x.isin(matchable_ibis_values).fillna(False)
10661066

10671067

10681068
@scalar_op_compiler.register_unary_op(ops.ToDatetimeOp, pass_op=True)

bigframes/core/compile/polars/compiler.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,11 +259,9 @@ def _(self, op: ops.ScalarOp, l_input: pl.Expr, r_input: pl.Expr) -> pl.Expr:
259259
def _(self, op: ops.ScalarOp, input: pl.Expr) -> pl.Expr:
260260
# TODO: Filter out types that can't be coerced to right type
261261
assert isinstance(op, gen_ops.IsInOp)
262-
if op.match_nulls or not any(map(pd.isna, op.values)):
263-
# newer polars version have nulls_equal arg
264-
return input.is_in(op.values)
265-
else:
266-
return input.is_in(op.values) or input.is_null()
262+
assert not op.match_nulls # should be stripped by a lowering step rn
263+
values = pl.Series(op.values, strict=False)
264+
return input.is_in(values)
267265

268266
@compile_op.register(gen_ops.FillNaOp)
269267
@compile_op.register(gen_ops.CoalesceOp)

bigframes/core/compile/polars/lowering.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
# limitations under the License.
1414

1515
import dataclasses
16+
from typing import cast
1617

1718
import numpy as np
19+
import pandas as pd
1820

1921
from bigframes import dtypes
2022
from bigframes.core import bigframe_node, expression
@@ -316,6 +318,35 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression:
316318
return expr
317319

318320

321+
class LowerIsinOp(op_lowering.OpLoweringRule):
    """Rewrite ``IsInOp`` into a null-free membership test plus a coalesce.

    The lowered expression is an ``IsInOp`` with ``match_nulls=False`` whose
    value list contains only the entries that ``dtypes.is_compatible`` accepts
    for the argument's output type, wrapped in ``coalesce_op`` so that a null
    membership result is replaced by a constant boolean.
    """

    @property
    def op(self) -> type[ops.ScalarOp]:
        """The scalar op type handled by this rule."""
        return generic_ops.IsInOp

    def lower(self, expr: expression.OpExpression) -> expression.Expression:
        """Return the lowered form of an ``IsInOp`` expression.

        Args:
            expr: Op expression whose ``op`` is an ``IsInOp``; its first child
                is the value being tested for membership.

        Returns:
            ``coalesce(isin(arg, kept_values), fill)``, where ``fill`` is
            ``True`` only when the original op has ``match_nulls`` set and its
            value list contains a non-float null, and ``False`` otherwise.
        """
        isin_op = expr.op
        assert isinstance(isin_op, generic_ops.IsInOp)
        operand = expr.children[0]

        kept_values = []
        null_fill = False
        for candidate in isin_op.values:
            # float NaN/inf should be treated as distinct from 'true' null
            # values, so floats are never classified as nulls here.
            is_true_null = cast(bool, pd.isna(candidate)) and not isinstance(
                candidate, float
            )
            if is_true_null:
                if isin_op.match_nulls:
                    null_fill = True
                continue
            # Values that are not compatible with the operand type are dropped.
            if dtypes.is_compatible(candidate, operand.output_type):
                kept_values.append(candidate)

        membership = ops.IsInOp(tuple(kept_values), match_nulls=False).as_expr(
            operand
        )
        # polars propagates nulls, so the result must always be coalesced:
        # to True when nulls should match, otherwise to False.
        return ops.coalesce_op.as_expr(membership, expression.const(null_fill))
348+
349+
319350
def _coerce_comparables(
320351
expr1: expression.Expression,
321352
expr2: expression.Expression,
@@ -414,6 +445,7 @@ def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression):
414445
LowerModRule(),
415446
LowerAsTypeRule(),
416447
LowerInvertOp(),
448+
LowerIsinOp(),
417449
)
418450

419451

bigframes/session/polars_executor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
generic_ops.FillNaOp,
6464
generic_ops.CaseWhenOp,
6565
generic_ops.InvertOp,
66+
generic_ops.IsInOp,
6667
)
6768
_COMPATIBLE_AGG_OPS = (
6869
agg_ops.SizeOp,

tests/system/small/engines/test_generic_ops.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,35 @@ def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
390390
)
391391

392392
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
393+
394+
395+
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine):
    """Check IsInOp gives equivalent results on the engine and the reference.

    Each case pairs a readable output column label with an ``IsInOp``
    expression applied to one of the scalar columns.
    """

    def isin(column: str, values: tuple, **op_kwargs):
        # Build ``column IN values`` as an expression over a fresh column ref.
        return ops.IsInOp(values, **op_kwargs).as_expr(expression.deref(column))

    cases = [
        ("int in ints", isin("int64_col", (1, 2, 3))),
        ("int in ints w null", isin("int64_col", (None, 123456))),
        (
            "int in ints w null wo match nulls",
            isin("int64_col", (None, 123456), match_nulls=False),
        ),
        ("int in floats", isin("int64_col", (1.0, 2.0, 3.0))),
        ("int in strings", isin("int64_col", ("1.0", "2.0"))),
        ("int in mixed", isin("int64_col", ("1.0", 2.5, 3))),
        ("int in empty", isin("int64_col", ())),
        ("float in ints", isin("float64_col", (1, 2, 3, None))),
    ]
    labels = tuple(label for label, _ in cases)

    arr, col_ids = scalars_array_value.compute_values(
        [case_expr for _, case_expr in cases]
    )
    # Replace the generated column ids with the labels, by position.
    renames = {col_id: labels[pos] for pos, col_id in enumerate(col_ids)}
    arr = arr.rename_columns(renames)

    assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)

0 commit comments

Comments
 (0)