Skip to content

Commit 7ba1166

Browse files
committed
complete isin compiler
1 parent e8bfd92 commit 7ba1166

File tree

5 files changed

+123
-19
lines changed

5 files changed

+123
-19
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import bigframes.core.compile.sqlglot.expressions.constants as constants
2828
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2929
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
30+
import bigframes.dtypes as dtypes
3031

3132
UNARY_OP_REGISTRATION = OpRegistration()
3233

@@ -420,9 +421,28 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
420421

421422
@UNARY_OP_REGISTRATION.register(ops.IsInOp)
def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``IsInOp`` into a SQL membership test.

    Args:
        op: The operation. ``op.values`` holds the candidate literals and
            ``op.match_nulls`` decides whether a null candidate makes a
            null operand count as a match.
        expr: The typed operand whose membership is being tested.

    Returns:
        A SQL expression: a constant ``FALSE`` when no usable candidate
        remains, an ``IS NULL`` test (optionally OR-ed with ``IN``) when
        nulls must match, or ``COALESCE(<expr> IN (...), FALSE)`` otherwise.
    """
    # Keep only candidates that can be compared with the operand: either the
    # dtype matches exactly, or both sides are numeric.
    is_numeric_expr = dtypes.is_numeric(expr.dtype)
    values = []
    for value in op.values:
        if value is None:
            continue
        dtype = dtypes.bigframes_type(type(value))
        if expr.dtype == dtype or (is_numeric_expr and dtypes.is_numeric(dtype)):
            values.append(sge.convert(value))

    if op.match_nulls and any(_is_null(value) for value in op.values):
        null_check = sge.Is(this=expr.expr, expression=sge.Null())
        if not values:
            # Only null (or type-incompatible) candidates were given. Emitting
            # "<expr> IN ()" would be an empty IN list, so the null test alone
            # is the whole predicate.
            return null_check
        return null_check | sge.In(this=expr.expr, expressions=values)

    if not values:
        return sge.convert(False)

    # NOTE(review): COALESCE presumably guards against IN evaluating to NULL
    # for a NULL operand, so that the result is always a boolean - confirm.
    return sge.func(
        "COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False)
    )
426446

427447

428448
@UNARY_OP_REGISTRATION.register(ops.isalnum_op)
@@ -868,3 +888,9 @@ def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression:
868888
],
869889
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
870890
)
891+
892+
893+
# Helpers
894+
def _is_null(value) -> bool:
895+
# float NaN/inf should be treated as distinct from 'true' null values
896+
return typing.cast(bool, pd.isna(value)) and not isinstance(value, float)

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/empty.sql

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
`bfcol_0` IN (1, 2, 3) AS `bfcol_1`
8+
COALESCE(`bfcol_0` IN (1, 2, 3), FALSE) AS `bfcol_1`
99
FROM `bfcte_0`
1010
)
1111
SELECT
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`bytes_col` AS `bfcol_1`,
5+
`date_col` AS `bfcol_2`,
6+
`datetime_col` AS `bfcol_3`,
7+
`geography_col` AS `bfcol_4`,
8+
`int64_col` AS `bfcol_5`,
9+
`int64_too` AS `bfcol_6`,
10+
`numeric_col` AS `bfcol_7`,
11+
`float64_col` AS `bfcol_8`,
12+
`rowindex` AS `bfcol_9`,
13+
`rowindex_2` AS `bfcol_10`,
14+
`string_col` AS `bfcol_11`,
15+
`time_col` AS `bfcol_12`,
16+
`timestamp_col` AS `bfcol_13`,
17+
`duration_col` AS `bfcol_14`
18+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
19+
), `bfcte_1` AS (
20+
SELECT
21+
*,
22+
COALESCE(`bfcol_5` IN (1, 2, 3), FALSE) AS `bfcol_31`,
23+
(
24+
`bfcol_5` IS NULL
25+
) OR `bfcol_5` IN (123456) AS `bfcol_32`,
26+
COALESCE(`bfcol_5` IN (123456), FALSE) AS `bfcol_33`,
27+
COALESCE(`bfcol_5` IN (1.0, 2.0, 3.0), FALSE) AS `bfcol_34`,
28+
FALSE AS `bfcol_35`,
29+
COALESCE(`bfcol_5` IN (2.5, 3), FALSE) AS `bfcol_36`,
30+
FALSE AS `bfcol_37`,
31+
(
32+
`bfcol_8` IS NULL
33+
) OR `bfcol_8` IN (1, 2, 3) AS `bfcol_38`
34+
FROM `bfcte_0`
35+
)
36+
SELECT
37+
`bfcol_9` AS `bfuid_col_1`,
38+
`bfcol_0` AS `bool_col`,
39+
`bfcol_1` AS `bytes_col`,
40+
`bfcol_2` AS `date_col`,
41+
`bfcol_3` AS `datetime_col`,
42+
`bfcol_4` AS `geography_col`,
43+
`bfcol_5` AS `int64_col`,
44+
`bfcol_6` AS `int64_too`,
45+
`bfcol_7` AS `numeric_col`,
46+
`bfcol_8` AS `float64_col`,
47+
`bfcol_9` AS `rowindex`,
48+
`bfcol_10` AS `rowindex_2`,
49+
`bfcol_11` AS `string_col`,
50+
`bfcol_12` AS `time_col`,
51+
`bfcol_13` AS `timestamp_col`,
52+
`bfcol_14` AS `duration_col`,
53+
`bfcol_31` AS `int in ints`,
54+
`bfcol_32` AS `int in ints w null`,
55+
`bfcol_33` AS `int in ints w null wo match nulls`,
56+
`bfcol_34` AS `int in floats`,
57+
`bfcol_35` AS `int in strings`,
58+
`bfcol_36` AS `int in mixed`,
59+
`bfcol_37` AS `int in empty`,
60+
`bfcol_38` AS `float in ints`
61+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
from bigframes import operations as ops
18+
from bigframes.core import expression
1819
from bigframes.operations._op_converters import convert_index, convert_slice
1920
import bigframes.pandas as bpd
2021

@@ -307,14 +308,43 @@ def test_invert(scalar_types_df: bpd.DataFrame, snapshot):
307308

308309
def test_is_in(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``IsInOp`` with integer candidates."""
    column = "int64_col"
    is_in_op = ops.IsInOp(values=(1, 2, 3))

    compiled_sql = _apply_unary_op(scalar_types_df[[column]], is_in_op, column)

    snapshot.assert_match(compiled_sql, "out.sql")
314314

315-
sql = _apply_unary_op(bf_df, ops.IsInOp(values=()), "int64_col")
316315

317-
snapshot.assert_match(sql, "empty.sql")
316+
def test_is_in_for_all_cases(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``IsInOp`` over several candidate mixes."""

    def is_in(values, column="int64_col", **op_kwargs):
        # Build the IsInOp expression against a fresh reference to `column`.
        return ops.IsInOp(values, **op_kwargs).as_expr(expression.deref(column))

    # Each entry pairs the output column name with the expression producing it,
    # so names and expressions cannot drift out of step.
    cases = [
        ("int in ints", is_in((1, 2, 3))),
        ("int in ints w null", is_in((None, 123456))),
        ("int in ints w null wo match nulls", is_in((None, 123456), match_nulls=False)),
        ("int in floats", is_in((1.0, 2.0, 3.0))),
        ("int in strings", is_in(("1.0", "2.0"))),
        ("int in mixed", is_in(("1.0", 2.5, 3))),
        ("int in empty", is_in(())),
        ("float in ints", is_in((1, 2, 3, None), column="float64_col")),
    ]

    arr, col_ids = scalar_types_df._block.expr.compute_values(
        [case_expr for _, case_expr in cases]
    )
    renames = dict(zip(col_ids, (case_name for case_name, _ in cases)))
    arr = arr.rename_columns(renames)

    sql = arr.session._executor.to_sql(arr, enable_cache=False)

    snapshot.assert_match(sql, "out.sql")
318348

319349

320350
def test_isalnum(scalar_types_df: bpd.DataFrame, snapshot):

0 commit comments

Comments
 (0)