Skip to content

Commit 173efd9

Browse files
authored
refactor: adds null literal checks when compiling eq and ne ops with sqlglot (#2381)
This change adds null literal checks during the compilation of eq, ne and map ops. This aims to resolve the `test_series_replace_nans_with_pd_na` failure reported in #2248. Fixes internal issue 417774347 🦕
1 parent a634e97 commit 173efd9

File tree

8 files changed

+109
-30
lines changed

8 files changed

+109
-30
lines changed

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616

1717
import typing
1818

19+
import bigframes_vendored.sqlglot as sg
1920
import bigframes_vendored.sqlglot.expressions as sge
2021
import pandas as pd
2122

2223
from bigframes import dtypes
2324
from bigframes import operations as ops
25+
from bigframes.core.compile.sqlglot import sqlglot_ir
2426
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2527
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2628

@@ -62,6 +64,10 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression:
6264

6365
@register_binary_op(ops.eq_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile an equality comparison between two typed expressions.

    When either operand is a NULL literal (as judged by
    ``sqlglot_ir._is_null_literal``), the comparison is emitted as
    ``<other operand> IS NULL``. The left operand is inspected first.
    Otherwise both operands go through ``_coerce_bool_to_int`` and are
    compared with ``=``.
    """
    # `x = NULL` does not yield TRUE/FALSE in SQL, so a NULL literal on either
    # side is rewritten into an IS NULL test on the opposite operand.
    for literal_side, other_side in ((left, right), (right, left)):
        if sqlglot_ir._is_null_literal(literal_side.expr):
            return sge.Is(this=other_side.expr, expression=sge.Null())

    return sge.EQ(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
@@ -139,6 +145,17 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
139145

140146
@register_binary_op(ops.ne_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile an inequality comparison between two typed expressions.

    When either operand is a NULL literal (as judged by
    ``sqlglot_ir._is_null_literal``), the comparison is emitted as
    ``(<other operand>) IS NOT NULL``. The left operand is inspected first.
    Otherwise both operands go through ``_coerce_bool_to_int`` and are
    compared with ``<>``.
    """
    # `x <> NULL` does not yield TRUE/FALSE in SQL, so a NULL literal on either
    # side is rewritten into a parenthesized IS NOT NULL test on the opposite
    # operand.
    for literal_side, other_side in ((left, right), (right, left)):
        if sqlglot_ir._is_null_literal(literal_side.expr):
            return sge.Is(
                this=sge.paren(other_side.expr, copy=False),
                expression=sg.not_(sge.Null(), copy=False),
            )

    return sge.NEQ(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from bigframes import dtypes
2121
from bigframes import operations as ops
22-
from bigframes.core.compile.sqlglot import sqlglot_types
22+
from bigframes.core.compile.sqlglot import sqlglot_ir, sqlglot_types
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2525

@@ -101,11 +101,23 @@ def _(expr: TypedExpr) -> sge.Expression:
101101
def _(expr: TypedExpr, op: ops.MapOp) -> sge.Expression:
102102
if len(op.mappings) == 0:
103103
return expr.expr
104+
105+
mappings = [
106+
(
107+
sqlglot_ir._literal(key, dtypes.is_compatible(key, expr.dtype)),
108+
sqlglot_ir._literal(value, dtypes.is_compatible(value, expr.dtype)),
109+
)
110+
for key, value in op.mappings
111+
]
104112
return sge.Case(
105-
this=expr.expr,
106113
ifs=[
107-
sge.If(this=sge.convert(key), true=sge.convert(value))
108-
for key, value in op.mappings
114+
sge.If(
115+
this=sge.EQ(this=expr.expr, expression=key)
116+
if not sqlglot_ir._is_null_literal(key)
117+
else sge.Is(this=expr.expr, expression=sge.Null()),
118+
true=value,
119+
)
120+
for key, value in mappings
109121
],
110122
default=expr.expr,
111123
)

bigframes/core/compile/sqlglot/sqlglot_ir.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,15 @@ def _select_to_cte(expr: sge.Select, cte_name: sge.Identifier) -> sge.Select:
642642
return new_select_expr
643643

644644

645+
def _is_null_literal(expr: sge.Expression) -> bool:
    """Return whether ``expr`` is a NULL literal.

    Both a bare ``sge.Null`` node and an ``sge.Cast`` whose direct operand is
    an ``sge.Null`` count as NULL literals. Anything else does not.
    """
    # Look through a single cast layer so a casted NULL is treated the same as
    # a bare NULL.
    candidate = expr.this if isinstance(expr, sge.Cast) else expr
    return isinstance(candidate, sge.Null)
652+
653+
645654
def _literal(value: typing.Any, dtype: dtypes.Dtype) -> sge.Expression:
646655
sqlglot_type = sgt.from_bigframes_dtype(dtype) if dtype else None
647656
if sqlglot_type is None:

tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ WITH `bfcte_0` AS (
2929
`bfcol_16` AS `bfcol_26`,
3030
`bfcol_17` AS `bfcol_27`,
3131
`bfcol_18` AS `bfcol_28`,
32-
`bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29`
32+
`bfcol_15` IS NULL AS `bfcol_29`
3333
FROM `bfcte_2`
3434
), `bfcte_4` AS (
3535
SELECT
@@ -40,15 +40,28 @@ WITH `bfcte_0` AS (
4040
`bfcol_27` AS `bfcol_39`,
4141
`bfcol_28` AS `bfcol_40`,
4242
`bfcol_29` AS `bfcol_41`,
43-
CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42`
43+
`bfcol_25` = CAST(`bfcol_26` AS INT64) AS `bfcol_42`
4444
FROM `bfcte_3`
45+
), `bfcte_5` AS (
46+
SELECT
47+
*,
48+
`bfcol_36` AS `bfcol_50`,
49+
`bfcol_37` AS `bfcol_51`,
50+
`bfcol_38` AS `bfcol_52`,
51+
`bfcol_39` AS `bfcol_53`,
52+
`bfcol_40` AS `bfcol_54`,
53+
`bfcol_41` AS `bfcol_55`,
54+
`bfcol_42` AS `bfcol_56`,
55+
CAST(`bfcol_38` AS INT64) = `bfcol_37` AS `bfcol_57`
56+
FROM `bfcte_4`
4557
)
4658
SELECT
47-
`bfcol_36` AS `rowindex`,
48-
`bfcol_37` AS `int64_col`,
49-
`bfcol_38` AS `bool_col`,
50-
`bfcol_39` AS `int_ne_int`,
51-
`bfcol_40` AS `int_ne_1`,
52-
`bfcol_41` AS `int_ne_bool`,
53-
`bfcol_42` AS `bool_ne_int`
54-
FROM `bfcte_4`
59+
`bfcol_50` AS `rowindex`,
60+
`bfcol_51` AS `int64_col`,
61+
`bfcol_52` AS `bool_col`,
62+
`bfcol_53` AS `int_eq_int`,
63+
`bfcol_54` AS `int_eq_1`,
64+
`bfcol_55` AS `int_eq_null`,
65+
`bfcol_56` AS `int_eq_bool`,
66+
`bfcol_57` AS `bool_eq_int`
67+
FROM `bfcte_5`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ WITH `bfcte_0` AS (
2929
`bfcol_16` AS `bfcol_26`,
3030
`bfcol_17` AS `bfcol_27`,
3131
`bfcol_18` AS `bfcol_28`,
32-
`bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29`
32+
(
33+
`bfcol_15`
34+
) IS NOT NULL AS `bfcol_29`
3335
FROM `bfcte_2`
3436
), `bfcte_4` AS (
3537
SELECT
@@ -40,15 +42,28 @@ WITH `bfcte_0` AS (
4042
`bfcol_27` AS `bfcol_39`,
4143
`bfcol_28` AS `bfcol_40`,
4244
`bfcol_29` AS `bfcol_41`,
43-
CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42`
45+
`bfcol_25` <> CAST(`bfcol_26` AS INT64) AS `bfcol_42`
4446
FROM `bfcte_3`
47+
), `bfcte_5` AS (
48+
SELECT
49+
*,
50+
`bfcol_36` AS `bfcol_50`,
51+
`bfcol_37` AS `bfcol_51`,
52+
`bfcol_38` AS `bfcol_52`,
53+
`bfcol_39` AS `bfcol_53`,
54+
`bfcol_40` AS `bfcol_54`,
55+
`bfcol_41` AS `bfcol_55`,
56+
`bfcol_42` AS `bfcol_56`,
57+
CAST(`bfcol_38` AS INT64) <> `bfcol_37` AS `bfcol_57`
58+
FROM `bfcte_4`
4559
)
4660
SELECT
47-
`bfcol_36` AS `rowindex`,
48-
`bfcol_37` AS `int64_col`,
49-
`bfcol_38` AS `bool_col`,
50-
`bfcol_39` AS `int_ne_int`,
51-
`bfcol_40` AS `int_ne_1`,
52-
`bfcol_41` AS `int_ne_bool`,
53-
`bfcol_42` AS `bool_ne_int`
54-
FROM `bfcte_4`
61+
`bfcol_50` AS `rowindex`,
62+
`bfcol_51` AS `int64_col`,
63+
`bfcol_52` AS `bool_col`,
64+
`bfcol_53` AS `int_ne_int`,
65+
`bfcol_54` AS `int_ne_1`,
66+
`bfcol_55` AS `int_ne_null`,
67+
`bfcol_56` AS `int_ne_bool`,
68+
`bfcol_57` AS `bool_ne_int`
69+
FROM `bfcte_5`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_map/out.sql

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,13 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8-
CASE `string_col` WHEN 'value1' THEN 'mapped1' ELSE `string_col` END AS `bfcol_1`
8+
CASE
9+
WHEN `string_col` = 'value1'
10+
THEN 'mapped1'
11+
WHEN `string_col` IS NULL
12+
THEN 'UNKNOWN'
13+
ELSE `string_col`
14+
END AS `bfcol_1`
915
FROM `bfcte_0`
1016
)
1117
SELECT

tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,12 @@ def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot):
5959
def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot):
6060
bf_df = scalar_types_df[["int64_col", "bool_col"]]
6161

62-
bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"]
63-
bf_df["int_ne_1"] = bf_df["int64_col"] == 1
62+
bf_df["int_eq_int"] = bf_df["int64_col"] == bf_df["int64_col"]
63+
bf_df["int_eq_1"] = bf_df["int64_col"] == 1
64+
bf_df["int_eq_null"] = bf_df["int64_col"] == pd.NA
6465

65-
bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"]
66-
bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"]
66+
bf_df["int_eq_bool"] = bf_df["int64_col"] == bf_df["bool_col"]
67+
bf_df["bool_eq_int"] = bf_df["bool_col"] == bf_df["int64_col"]
6768

6869
snapshot.assert_match(bf_df.sql, "out.sql")
6970

@@ -135,6 +136,7 @@ def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot):
135136

136137
bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"]
137138
bf_df["int_ne_1"] = bf_df["int64_col"] != 1
139+
bf_df["int_ne_null"] = bf_df["int64_col"] != pd.NA
138140

139141
bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"]
140142
bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"]

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import pandas as pd
1516
import pytest
1617

1718
from bigframes import dtypes
@@ -342,7 +343,11 @@ def test_map(scalar_types_df: bpd.DataFrame, snapshot):
342343
bf_df = scalar_types_df[[col_name]]
343344
sql = utils._apply_ops_to_sql(
344345
bf_df,
345-
[ops.MapOp(mappings=(("value1", "mapped1"),)).as_expr(col_name)],
346+
[
347+
ops.MapOp(mappings=(("value1", "mapped1"), (pd.NA, "UNKNOWN"))).as_expr(
348+
col_name
349+
)
350+
],
346351
[col_name],
347352
)
348353

0 commit comments

Comments
 (0)