Skip to content

Commit 0197055

Browse files
committed
chore: fix ops.ToTimedeltaOp and ops.IsInOp sqlglot compiler
1 parent 6370d3b commit 0197055

File tree

7 files changed

+72
-16
lines changed

7 files changed

+72
-16
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import bigframes.core.compile.sqlglot.expressions.constants as constants
2828
from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration
2929
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
30+
import bigframes.dtypes as dtypes
3031

3132
UNARY_OP_REGISTRATION = OpRegistration()
3233

@@ -420,7 +421,28 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression:
420421

421422
@UNARY_OP_REGISTRATION.register(ops.IsInOp)
def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression:
    """Compile ``ops.IsInOp`` into a SQL membership test.

    Args:
        op: The operation being compiled. ``op.values`` holds the candidate
            literals and ``op.match_nulls`` controls whether a null candidate
            makes null inputs match.
        expr: The compiled operand together with its dtype.

    Returns:
        One of:
        * ``<expr> IS NULL`` (optionally OR-ed with ``<expr> IN (...)``) when
          nulls must match and a null candidate is present,
        * a literal ``FALSE`` when no candidate is comparable with the operand,
        * ``COALESCE(<expr> IN (...), FALSE)`` otherwise.
    """
    is_numeric_expr = dtypes.is_numeric(expr.dtype)

    # Keep only the literals that are comparable with the operand: either the
    # dtypes are equal, or both sides are numeric. ``None`` carries no dtype
    # and is handled separately through ``match_nulls`` below.
    values = []
    for value in op.values:
        if value is None:
            continue
        dtype = dtypes.bigframes_type(type(value))
        if expr.dtype == dtype or (is_numeric_expr and dtypes.is_numeric(dtype)):
            values.append(sge.convert(value))

    if op.match_nulls and any(_is_null(value) for value in op.values):
        is_null = sge.Is(this=expr.expr, expression=sge.Null())
        if not values:
            # Only null candidates are usable. Emitting ``IN`` with no
            # expressions would render an empty ``IN ()`` list, which is not
            # valid BigQuery SQL, so the null test alone is the whole answer.
            return is_null
        return is_null | sge.In(this=expr.expr, expressions=values)

    if not values:
        # Nothing comparable to test against: the result is always false.
        return sge.convert(False)

    # COALESCE turns a NULL result of the IN test into FALSE.
    return sge.func(
        "COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False)
    )
424446

425447

426448
@UNARY_OP_REGISTRATION.register(ops.isalnum_op)
@@ -767,7 +789,7 @@ def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression:
767789
factor = UNIT_TO_US_CONVERSION_FACTORS[op.unit]
768790
if factor != 1:
769791
value = sge.Mul(this=value, expression=sge.convert(factor))
770-
return sge.Interval(this=value, unit=sge.Identifier(this="MICROSECOND"))
792+
return value
771793

772794

773795
@UNARY_OP_REGISTRATION.register(ops.UnixMicros)
@@ -866,3 +888,9 @@ def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression:
866888
],
867889
default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")),
868890
)
891+
892+
893+
# Helpers
894+
def _is_null(value) -> bool:
895+
# float NaN/inf should be treated as distinct from 'true' null values
896+
return typing.cast(bool, pd.isna(value)) and not isinstance(value, float)

tests/system/small/engines/test_generic_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
392392
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
393393

394394

395-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
395+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
396396
def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine):
397397
arr, col_ids = scalars_array_value.compute_values(
398398
[

tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ WITH `bfcte_0` AS (
1111
`bfcol_1` AS `bfcol_8`,
1212
`bfcol_2` AS `bfcol_9`,
1313
`bfcol_0` AS `bfcol_10`,
14-
INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11`
14+
`bfcol_3` AS `bfcol_11`
1515
FROM `bfcte_0`
1616
), `bfcte_2` AS (
1717
SELECT

tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ WITH `bfcte_0` AS (
1111
`bfcol_1` AS `bfcol_8`,
1212
`bfcol_2` AS `bfcol_9`,
1313
`bfcol_0` AS `bfcol_10`,
14-
INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11`
14+
`bfcol_3` AS `bfcol_11`
1515
FROM `bfcte_0`
1616
), `bfcte_2` AS (
1717
SELECT
tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/out.sql

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,32 @@
11
WITH `bfcte_0` AS (
22
SELECT
3-
`int64_col` AS `bfcol_0`
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`
45
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
56
), `bfcte_1` AS (
67
SELECT
78
*,
8-
`bfcol_0` IN (1, 2, 3) AS `bfcol_1`
9+
COALESCE(`bfcol_0` IN (1, 2, 3), FALSE) AS `bfcol_2`,
10+
(
11+
`bfcol_0` IS NULL
12+
) OR `bfcol_0` IN (123456) AS `bfcol_3`,
13+
COALESCE(`bfcol_0` IN (1.0, 2.0, 3.0), FALSE) AS `bfcol_4`,
14+
FALSE AS `bfcol_5`,
15+
COALESCE(`bfcol_0` IN (2.5, 3), FALSE) AS `bfcol_6`,
16+
FALSE AS `bfcol_7`,
17+
COALESCE(`bfcol_0` IN (123456), FALSE) AS `bfcol_8`,
18+
(
19+
`bfcol_1` IS NULL
20+
) OR `bfcol_1` IN (1, 2, 3) AS `bfcol_9`
921
FROM `bfcte_0`
1022
)
1123
SELECT
12-
`bfcol_1` AS `int64_col`
24+
`bfcol_2` AS `ints`,
25+
`bfcol_3` AS `ints_w_null`,
26+
`bfcol_4` AS `floats`,
27+
`bfcol_5` AS `strings`,
28+
`bfcol_6` AS `mixed`,
29+
`bfcol_7` AS `empty`,
30+
`bfcol_8` AS `ints_wo_match_nulls`,
31+
`bfcol_9` AS `float_in_ints`
1332
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_to_timedelta/out.sql

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@ WITH `bfcte_0` AS (
88
*,
99
`bfcol_1` AS `bfcol_4`,
1010
`bfcol_0` AS `bfcol_5`,
11-
INTERVAL `bfcol_0` MICROSECOND AS `bfcol_6`
11+
`bfcol_0` AS `bfcol_6`
1212
FROM `bfcte_0`
1313
), `bfcte_2` AS (
1414
SELECT
1515
*,
1616
`bfcol_4` AS `bfcol_10`,
1717
`bfcol_5` AS `bfcol_11`,
1818
`bfcol_6` AS `bfcol_12`,
19-
INTERVAL (`bfcol_5` * 1000000) MICROSECOND AS `bfcol_13`
19+
`bfcol_5` * 1000000 AS `bfcol_13`
2020
FROM `bfcte_1`
2121
), `bfcte_3` AS (
2222
SELECT
@@ -25,7 +25,7 @@ WITH `bfcte_0` AS (
2525
`bfcol_11` AS `bfcol_19`,
2626
`bfcol_12` AS `bfcol_20`,
2727
`bfcol_13` AS `bfcol_21`,
28-
INTERVAL (`bfcol_11` * 604800000000) MICROSECOND AS `bfcol_22`
28+
`bfcol_11` * 604800000000 AS `bfcol_22`
2929
FROM `bfcte_2`
3030
)
3131
SELECT

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -370,12 +370,21 @@ def test_invert(scalar_types_df: bpd.DataFrame, snapshot):
370370

371371

372372
def test_is_in(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``IsInOp`` over several candidate lists."""
    int_col, float_col = "int64_col", "float64_col"
    bf_df = scalar_types_df[[int_col, float_col]]

    # (output column name, expression) pairs, passed on in this order.
    cases = [
        ("ints", ops.IsInOp(values=(1, 2, 3)).as_expr(int_col)),
        ("ints_w_null", ops.IsInOp(values=(None, 123456)).as_expr(int_col)),
        (
            "floats",
            ops.IsInOp(values=(1.0, 2.0, 3.0), match_nulls=False).as_expr(int_col),
        ),
        ("strings", ops.IsInOp(values=("1.0", "2.0")).as_expr(int_col)),
        ("mixed", ops.IsInOp(values=("1.0", 2.5, 3)).as_expr(int_col)),
        ("empty", ops.IsInOp(values=()).as_expr(int_col)),
        (
            "ints_wo_match_nulls",
            ops.IsInOp(values=(None, 123456), match_nulls=False).as_expr(int_col),
        ),
        ("float_in_ints", ops.IsInOp(values=(1, 2, 3, None)).as_expr(float_col)),
    ]
    names = [name for name, _ in cases]
    expressions = [expression for _, expression in cases]

    sql = _apply_unary_ops(bf_df, expressions, names)
    snapshot.assert_match(sql, "out.sql")
380389

381390

0 commit comments

Comments
 (0)