Skip to content

Commit b454256

Browse files
authored
chore: implement eq, eq_null_match, ne compilers (#2008)
Fixes internal issue 430133370
1 parent f61b044 commit b454256

File tree

5 files changed

+214
-46
lines changed

5 files changed

+214
-46
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 62 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -38,31 +38,23 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3838
return sge.Concat(expressions=[left.expr, right.expr])
3939

4040
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
41-
left_expr = left.expr
42-
if left.dtype == dtypes.BOOL_DTYPE:
43-
left_expr = sge.Cast(this=left_expr, to="INT64")
44-
right_expr = right.expr
45-
if right.dtype == dtypes.BOOL_DTYPE:
46-
right_expr = sge.Cast(this=right_expr, to="INT64")
41+
left_expr = _coerce_bool_to_int(left)
42+
right_expr = _coerce_bool_to_int(right)
4743
return sge.Add(this=left_expr, expression=right_expr)
4844

4945
if (
5046
dtypes.is_time_or_date_like(left.dtype)
5147
and right.dtype == dtypes.TIMEDELTA_DTYPE
5248
):
53-
left_expr = left.expr
54-
if left.dtype == dtypes.DATE_DTYPE:
55-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
49+
left_expr = _coerce_date_to_datetime(left)
5650
return sge.TimestampAdd(
5751
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
5852
)
5953
if (
6054
dtypes.is_time_or_date_like(right.dtype)
6155
and left.dtype == dtypes.TIMEDELTA_DTYPE
6256
):
63-
right_expr = right.expr
64-
if right.dtype == dtypes.DATE_DTYPE:
65-
right_expr = sge.Cast(this=right_expr, to="DATETIME")
57+
right_expr = _coerce_date_to_datetime(right)
6658
return sge.TimestampAdd(
6759
this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND")
6860
)
@@ -74,14 +66,37 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7466
)
7567

7668

77-
@BINARY_OP_REGISTRATION.register(ops.div_op)
69+
@BINARY_OP_REGISTRATION.register(ops.eq_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ops.eq_op`` into a SQL equality comparison.

    Both operands are passed through ``_coerce_bool_to_int`` first, so a BOOL
    operand is cast to INT64 before the comparison is built.

    Args:
        op: The operator being compiled; not used by this compiler.
        left: Left-hand operand.
        right: Right-hand operand.

    Returns:
        An ``sge.EQ`` node over the (possibly cast) operands.
    """
    lhs, rhs = _coerce_bool_to_int(left), _coerce_bool_to_int(right)
    return sge.EQ(this=lhs, expression=rhs)
74+
75+
76+
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ops.eq_null_match_op``: equality where NULL matches NULL.

    Each operand is cast to STRING and wrapped in ``COALESCE`` with a shared
    sentinel literal, then the two results are compared with ``=``. A NULL on
    both sides therefore becomes the same sentinel string on both sides.

    Args:
        op: The operator being compiled; not used by this compiler.
        left: Left-hand operand.
        right: Right-hand operand.

    Returns:
        An ``sge.EQ`` node comparing the two coalesced string expressions.
    """
    # A BOOL operand is cast to INT64 only when the *other* operand is not
    # BOOL; a bool-vs-bool pair is left uncast.
    lhs = _coerce_bool_to_int(left) if right.dtype != dtypes.BOOL_DTYPE else left.expr
    rhs = _coerce_bool_to_int(right) if left.dtype != dtypes.BOOL_DTYPE else right.expr

    # NOTE(review): an operand whose string form is exactly this literal is
    # indistinguishable from NULL in the comparison below.
    null_marker = sge.convert("$NULL_SENTINEL$")

    def _null_as_marker(expr: sge.Expression) -> sge.Expression:
        # CAST(expr AS STRING), with NULL replaced by the sentinel literal.
        return sge.Coalesce(
            this=sge.Cast(this=expr, to="STRING"), expressions=[null_marker]
        )

    return sge.EQ(this=_null_as_marker(lhs), expression=_null_as_marker(rhs))
94+
95+
96+
@BINARY_OP_REGISTRATION.register(ops.div_op)
97+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
98+
left_expr = _coerce_bool_to_int(left)
99+
right_expr = _coerce_bool_to_int(right)
85100

86101
result = sge.func("IEEE_DIVIDE", left_expr, right_expr)
87102
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
@@ -92,12 +107,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
92107

93108
@BINARY_OP_REGISTRATION.register(ops.floordiv_op)
94109
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
95-
left_expr = left.expr
96-
if left.dtype == dtypes.BOOL_DTYPE:
97-
left_expr = sge.Cast(this=left_expr, to="INT64")
98-
right_expr = right.expr
99-
if right.dtype == dtypes.BOOL_DTYPE:
100-
right_expr = sge.Cast(this=right_expr, to="INT64")
110+
left_expr = _coerce_bool_to_int(left)
111+
right_expr = _coerce_bool_to_int(right)
101112

102113
result: sge.Expression = sge.Cast(
103114
this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64"
@@ -139,12 +150,8 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
139150

140151
@BINARY_OP_REGISTRATION.register(ops.mul_op)
141152
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
142-
left_expr = left.expr
143-
if left.dtype == dtypes.BOOL_DTYPE:
144-
left_expr = sge.Cast(this=left_expr, to="INT64")
145-
right_expr = right.expr
146-
if right.dtype == dtypes.BOOL_DTYPE:
147-
right_expr = sge.Cast(this=right_expr, to="INT64")
153+
left_expr = _coerce_bool_to_int(left)
154+
right_expr = _coerce_bool_to_int(right)
148155

149156
result = sge.Mul(this=left_expr, expression=right_expr)
150157

@@ -156,36 +163,33 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
156163
return result
157164

158165

166+
@BINARY_OP_REGISTRATION.register(ops.ne_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ops.ne_op`` into a SQL inequality comparison.

    Both operands are passed through ``_coerce_bool_to_int`` first, so a BOOL
    operand is cast to INT64 before the comparison is built.

    Args:
        op: The operator being compiled; not used by this compiler.
        left: Left-hand operand.
        right: Right-hand operand.

    Returns:
        An ``sge.NEQ`` node over the (possibly cast) operands.
    """
    lhs, rhs = _coerce_bool_to_int(left), _coerce_bool_to_int(right)
    return sge.NEQ(this=lhs, expression=rhs)
171+
172+
159173
@BINARY_OP_REGISTRATION.register(ops.sub_op)
160174
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
161175
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
162-
left_expr = left.expr
163-
if left.dtype == dtypes.BOOL_DTYPE:
164-
left_expr = sge.Cast(this=left_expr, to="INT64")
165-
right_expr = right.expr
166-
if right.dtype == dtypes.BOOL_DTYPE:
167-
right_expr = sge.Cast(this=right_expr, to="INT64")
176+
left_expr = _coerce_bool_to_int(left)
177+
right_expr = _coerce_bool_to_int(right)
168178
return sge.Sub(this=left_expr, expression=right_expr)
169179

170180
if (
171181
dtypes.is_time_or_date_like(left.dtype)
172182
and right.dtype == dtypes.TIMEDELTA_DTYPE
173183
):
174-
left_expr = left.expr
175-
if left.dtype == dtypes.DATE_DTYPE:
176-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
184+
left_expr = _coerce_date_to_datetime(left)
177185
return sge.TimestampSub(
178186
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
179187
)
180188
if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like(
181189
right.dtype
182190
):
183-
left_expr = left.expr
184-
if left.dtype == dtypes.DATE_DTYPE:
185-
left_expr = sge.Cast(this=left_expr, to="DATETIME")
186-
right_expr = right.expr
187-
if right.dtype == dtypes.DATE_DTYPE:
188-
right_expr = sge.Cast(this=right_expr, to="DATETIME")
191+
left_expr = _coerce_date_to_datetime(left)
192+
right_expr = _coerce_date_to_datetime(right)
189193
return sge.TimestampDiff(
190194
this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND")
191195
)
@@ -201,3 +205,17 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
201205
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ops.obj_make_ref_op`` into a call to ``OBJ.MAKE_REF``.

    Args:
        op: The operator being compiled; not used by this compiler.
        left: First argument of the function call.
        right: Second argument of the function call.

    Returns:
        The function-call expression built by ``sge.func``.
    """
    call_args = (left.expr, right.expr)
    return sge.func("OBJ.MAKE_REF", *call_args)
208+
209+
210+
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to INT64 when its dtype is BOOL.

    Args:
        typed_expr: Operand carrying an expression and its dtype.

    Returns:
        ``CAST(expr AS INT64)`` for a BOOL operand; otherwise the operand's
        expression unchanged.
    """
    is_bool = typed_expr.dtype == dtypes.BOOL_DTYPE
    return sge.Cast(this=typed_expr.expr, to="INT64") if is_bool else typed_expr.expr
215+
216+
217+
def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to DATETIME when its dtype is DATE.

    Args:
        typed_expr: Operand carrying an expression and its dtype.

    Returns:
        ``CAST(expr AS DATETIME)`` for a DATE operand; otherwise the operand's
        expression unchanged.
    """
    is_date = typed_expr.dtype == dtypes.DATE_DTYPE
    return sge.Cast(this=typed_expr.expr, to="DATETIME") if is_date else typed_expr.expr
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
COALESCE(CAST(`bfcol_1` AS STRING), '$NULL_SENTINEL$') = COALESCE(CAST(CAST(`bfcol_0` AS INT64) AS STRING), '$NULL_SENTINEL$') AS `bfcol_4`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_4` AS `int64_col`
14+
FROM `bfcte_1`
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
`bfcol_1` = `bfcol_1` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
`bfcol_7` = 1 AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
`bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_ne_int`,
51+
`bfcol_40` AS `int_ne_1`,
52+
`bfcol_41` AS `int_ne_bool`,
53+
`bfcol_42` AS `bool_ne_int`
54+
FROM `bfcte_4`
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
`bfcol_1` <> `bfcol_1` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
`bfcol_7` <> 1 AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
`bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_ne_int`,
51+
`bfcol_40` AS `int_ne_1`,
52+
`bfcol_41` AS `int_ne_bool`,
53+
`bfcol_42` AS `bool_ne_int`
54+
FROM `bfcte_4`

tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,24 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
107107
snapshot.assert_match(bf_df.sql, "out.sql")
108108

109109

110+
def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``eq_null_match_op`` on int64 vs bool."""
    operand_cols = ["int64_col", "bool_col"]
    compiled_sql = _apply_binary_op(
        scalar_types_df[operand_cols], ops.eq_null_match_op, *operand_cols
    )
    snapshot.assert_match(compiled_sql, "out.sql")
114+
115+
116+
def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``==`` across int64 and bool columns."""
    df = scalar_types_df[["int64_col", "bool_col"]]

    # NOTE: the labels say "ne" although these are equality comparisons. The
    # labels appear in the stored snapshot, so renaming them means
    # regenerating that snapshot as well.

    # int64 against itself and against an integer literal.
    df["int_ne_int"] = df["int64_col"] == df["int64_col"]
    df["int_ne_1"] = df["int64_col"] == 1

    # Mixed int64/bool comparison, in both operand orders.
    df["int_ne_bool"] = df["int64_col"] == df["bool_col"]
    df["bool_ne_int"] = df["bool_col"] == df["int64_col"]

    snapshot.assert_match(df.sql, "out.sql")
126+
127+
110128
def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot):
111129
bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]]
112130

@@ -121,8 +139,6 @@ def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot):
121139
bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"]
122140
bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"]
123141

124-
snapshot.assert_match(bf_df.sql, "out.sql")
125-
126142

127143
def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
128144
bf_df = scalar_types_df[["timestamp_col", "date_col"]]
@@ -200,3 +216,15 @@ def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
200216
def test_obj_make_ref(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled when ``string_col`` is converted to a blob."""
    blob_series = scalar_types_df["string_col"].str.to_blob()
    compiled_sql = blob_series.to_frame().sql
    snapshot.assert_match(compiled_sql, "out.sql")
219+
220+
221+
def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``!=`` across int64 and bool columns."""
    df = scalar_types_df[["int64_col", "bool_col"]]

    # int64 against itself and against an integer literal.
    df["int_ne_int"] = df["int64_col"] != df["int64_col"]
    df["int_ne_1"] = df["int64_col"] != 1

    # Mixed int64/bool comparison, in both operand orders.
    df["int_ne_bool"] = df["int64_col"] != df["bool_col"]
    df["bool_ne_int"] = df["bool_col"] != df["int64_col"]

    snapshot.assert_match(df.sql, "out.sql")

0 commit comments

Comments
 (0)