Skip to content

Commit 3664162

Browse files
committed
chore: implement eq, eq_null_match, ne compilers
1 parent 9af7130 commit 3664162

File tree

5 files changed

+198 additions, -21 deletions

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3737
return sge.Concat(expressions=[left.expr, right.expr])
3838

3939
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
40-
left_expr = left.expr
41-
if left.dtype == dtypes.BOOL_DTYPE:
42-
left_expr = sge.Cast(this=left_expr, to="INT64")
43-
right_expr = right.expr
44-
if right.dtype == dtypes.BOOL_DTYPE:
45-
right_expr = sge.Cast(this=right_expr, to="INT64")
40+
left_expr, right_expr = _coerce_bools(left, right)
4641
return sge.Add(this=left_expr, expression=right_expr)
4742

4843
if (
@@ -73,15 +68,36 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7368
)
7469

7570

76-
@BINARY_OP_REGISTRATION.register(ops.div_op)
71+
@BINARY_OP_REGISTRATION.register(ops.eq_op)
72+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
73+
left_expr, right_expr = _coerce_bools(left, right)
74+
return sge.EQ(this=left_expr, expression=right_expr)
75+
76+
77+
@BINARY_OP_REGISTRATION.register(ops.eq_null_match_op)
7778
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7879
left_expr = left.expr
79-
if left.dtype == dtypes.BOOL_DTYPE:
80+
if left.dtype == dtypes.BOOL_DTYPE and right.dtype != dtypes.BOOL_DTYPE:
8081
left_expr = sge.Cast(this=left_expr, to="INT64")
82+
8183
right_expr = right.expr
82-
if right.dtype == dtypes.BOOL_DTYPE:
84+
if right.dtype == dtypes.BOOL_DTYPE and left.dtype != dtypes.BOOL_DTYPE:
8385
right_expr = sge.Cast(this=right_expr, to="INT64")
8486

87+
sentinel = sge.convert("$NULL_SENTINEL$")
88+
left_coalesce = sge.Coalesce(
89+
this=sge.Cast(this=left_expr, to="STRING"), expressions=[sentinel]
90+
)
91+
right_coalesce = sge.Coalesce(
92+
this=sge.Cast(this=right_expr, to="STRING"), expressions=[sentinel]
93+
)
94+
return sge.EQ(this=left_coalesce, expression=right_coalesce)
95+
96+
97+
@BINARY_OP_REGISTRATION.register(ops.div_op)
98+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
99+
left_expr, right_expr = _coerce_bools(left, right)
100+
85101
result = sge.func("IEEE_DIVIDE", left_expr, right_expr)
86102
if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
87103
return sge.Cast(this=sge.Floor(this=result), to="INT64")
@@ -101,12 +117,7 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
101117

102118
@BINARY_OP_REGISTRATION.register(ops.mul_op)
103119
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
104-
left_expr = left.expr
105-
if left.dtype == dtypes.BOOL_DTYPE:
106-
left_expr = sge.Cast(this=left_expr, to="INT64")
107-
right_expr = right.expr
108-
if right.dtype == dtypes.BOOL_DTYPE:
109-
right_expr = sge.Cast(this=right_expr, to="INT64")
120+
left_expr, right_expr = _coerce_bools(left, right)
110121

111122
result = sge.Mul(this=left_expr, expression=right_expr)
112123

@@ -118,15 +129,16 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
118129
return result
119130

120131

132+
@BINARY_OP_REGISTRATION.register(ops.ne_op)
133+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
134+
left_expr, right_expr = _coerce_bools(left, right)
135+
return sge.NEQ(this=left_expr, expression=right_expr)
136+
137+
121138
@BINARY_OP_REGISTRATION.register(ops.sub_op)
122139
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
123140
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
124-
left_expr = left.expr
125-
if left.dtype == dtypes.BOOL_DTYPE:
126-
left_expr = sge.Cast(this=left_expr, to="INT64")
127-
right_expr = right.expr
128-
if right.dtype == dtypes.BOOL_DTYPE:
129-
right_expr = sge.Cast(this=right_expr, to="INT64")
141+
left_expr, right_expr = _coerce_bools(left, right)
130142
return sge.Sub(this=left_expr, expression=right_expr)
131143

132144
if (
@@ -163,3 +175,16 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
163175
@BINARY_OP_REGISTRATION.register(ops.obj_make_ref_op)
164176
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
165177
return sge.func("OBJ.MAKE_REF", left.expr, right.expr)
178+
179+
180+
def _coerce_bools(
181+
left: TypedExpr, right: TypedExpr
182+
) -> tuple[sge.Expression, sge.Expression]:
183+
"""Coerce boolean expressions to INT64 for binary operations."""
184+
left_expr = left.expr
185+
if left.dtype == dtypes.BOOL_DTYPE:
186+
left_expr = sge.Cast(this=left_expr, to="INT64")
187+
right_expr = right.expr
188+
if right.dtype == dtypes.BOOL_DTYPE:
189+
right_expr = sge.Cast(this=right_expr, to="INT64")
190+
return left_expr, right_expr
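
[File path not captured in this extract; the SQL below appears to be the `out.sql` snapshot asserted by `test_eq_null_match`.]

Lines changed: 14 additions & 0 deletions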
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
COALESCE(CAST(`bfcol_1` AS STRING), '$NULL_SENTINEL$') = COALESCE(CAST(CAST(`bfcol_0` AS INT64) AS STRING), '$NULL_SENTINEL$') AS `bfcol_4`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_4` AS `int64_col`
14+
FROM `bfcte_1`
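
[File path not captured in this extract; the SQL below appears to be the `out.sql` snapshot asserted by `test_eq_numeric`.]

Lines changed: 54 additions & 0 deletions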
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
`bfcol_1` = `bfcol_1` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
`bfcol_7` = 1 AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
`bfcol_15` = CAST(`bfcol_16` AS INT64) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
CAST(`bfcol_26` AS INT64) = `bfcol_25` AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_ne_int`,
51+
`bfcol_40` AS `int_ne_1`,
52+
`bfcol_41` AS `int_ne_bool`,
53+
`bfcol_42` AS `bool_ne_int`
54+
FROM `bfcte_4`
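
[File path not captured in this extract; the SQL below appears to be the `out.sql` snapshot asserted by `test_ne_numeric`.]

Lines changed: 54 additions & 0 deletions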
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
`bfcol_1` <> `bfcol_1` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
`bfcol_7` <> 1 AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
`bfcol_15` <> CAST(`bfcol_16` AS INT64) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
CAST(`bfcol_26` AS INT64) <> `bfcol_25` AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_ne_int`,
51+
`bfcol_40` AS `int_ne_1`,
52+
`bfcol_41` AS `int_ne_bool`,
53+
`bfcol_42` AS `bool_ne_int`
54+
FROM `bfcte_4`

tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,24 @@ def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
102102
snapshot.assert_match(bf_df.sql, "out.sql")
103103

104104

105+
def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot):
106+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
107+
108+
bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"]
109+
bf_df["int_ne_1"] = bf_df["int64_col"] == 1
110+
111+
bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"]
112+
bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"]
113+
114+
snapshot.assert_match(bf_df.sql, "out.sql")
115+
116+
117+
def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot):
118+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
119+
sql = _apply_binary_op(bf_df, ops.eq_null_match_op, "int64_col", "bool_col")
120+
snapshot.assert_match(sql, "out.sql")
121+
122+
105123
def test_json_set(json_types_df: bpd.DataFrame, snapshot):
106124
bf_df = json_types_df[["json_col"]]
107125
sql = _apply_binary_op(
@@ -169,3 +187,15 @@ def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
169187
def test_obj_make_ref(scalar_types_df: bpd.DataFrame, snapshot):
170188
blob_df = scalar_types_df["string_col"].str.to_blob()
171189
snapshot.assert_match(blob_df.to_frame().sql, "out.sql")
190+
191+
192+
def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot):
193+
bf_df = scalar_types_df[["int64_col", "bool_col"]]
194+
195+
bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"]
196+
bf_df["int_ne_1"] = bf_df["int64_col"] != 1
197+
198+
bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"]
199+
bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"]
200+
201+
snapshot.assert_match(bf_df.sql, "out.sql")

0 commit comments

Comments
 (0)