Skip to content

Commit ba93b5b

Browse files
authored
chore: implement mul_op and div_op compilers (#1987)
Fixes internal issue 430133370
1 parent 2720c4c commit ba93b5b

File tree

7 files changed

+250
-13
lines changed

7 files changed

+250
-13
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,51 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
7373
)
7474

7575

76+
@BINARY_OP_REGISTRATION.register(ops.div_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``div_op`` into an ``IEEE_DIVIDE`` function call.

    Args:
        op: The operator being compiled; not read by this compiler.
        left: Typed expression used as the dividend.
        right: Typed expression used as the divisor.

    Returns:
        The ``IEEE_DIVIDE`` expression, or that expression wrapped in
        ``CAST(FLOOR(...) AS INT64)`` when a timedelta is divided by a
        numeric value.
    """
    # Boolean operands are cast to INT64 before they take part in the division.
    dividend = (
        sge.Cast(this=left.expr, to="INT64")
        if left.dtype == dtypes.BOOL_DTYPE
        else left.expr
    )
    divisor = (
        sge.Cast(this=right.expr, to="INT64")
        if right.dtype == dtypes.BOOL_DTYPE
        else right.expr
    )

    quotient = sge.func("IEEE_DIVIDE", dividend, divisor)

    # timedelta / numeric is floored and cast back to INT64.
    # NOTE(review): presumably because timedeltas are represented as integer
    # counts -- confirm against the dtype definitions.
    if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype):
        return sge.Cast(this=sge.Floor(this=quotient), to="INT64")
    return quotient
90+
91+
92+
@BINARY_OP_REGISTRATION.register(ops.ge_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ge_op`` into a ``sge.GTE`` comparison node.

    Args:
        op: The operator being compiled; not read by this compiler.
        left: Typed expression placed on the left-hand side.
        right: Typed expression placed on the right-hand side.

    Returns:
        A ``sge.GTE`` expression over the two operand expressions.
    """
    lhs = left.expr
    rhs = right.expr
    return sge.GTE(this=lhs, expression=rhs)
95+
96+
97+
@BINARY_OP_REGISTRATION.register(ops.JSONSet)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``JSONSet`` into a ``JSON_SET`` function call.

    Args:
        op: The operator being compiled; its ``json_path`` attribute is
            passed as the second argument of the call.
        left: Typed expression passed as the first argument.
        right: Typed expression passed as the third argument.

    Returns:
        The ``JSON_SET(left, json_path, right)`` function expression.
    """
    target = left.expr
    # The path is a plain Python value on the op, so it is converted to a
    # SQL expression before being handed to the function builder.
    path = sge.convert(op.json_path)
    replacement = right.expr
    return sge.func("JSON_SET", target, path, replacement)
100+
101+
102+
@BINARY_OP_REGISTRATION.register(ops.mul_op)
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``mul_op`` into a ``sge.Mul`` expression.

    Args:
        op: The operator being compiled; not read by this compiler.
        left: Typed expression used as the left factor.
        right: Typed expression used as the right factor.

    Returns:
        The product expression, or that expression wrapped in
        ``CAST(FLOOR(...) AS INT64)`` when exactly one side is a timedelta
        and the other side is numeric.
    """
    # Boolean factors are cast to INT64 before they are multiplied.
    multiplicand = (
        sge.Cast(this=left.expr, to="INT64")
        if left.dtype == dtypes.BOOL_DTYPE
        else left.expr
    )
    multiplier = (
        sge.Cast(this=right.expr, to="INT64")
        if right.dtype == dtypes.BOOL_DTYPE
        else right.expr
    )

    product = sge.Mul(this=multiplicand, expression=multiplier)

    numeric_times_timedelta = (
        dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE
    )
    # Scaling a timedelta (in either operand order) is floored and cast back
    # to INT64.
    # NOTE(review): presumably because timedeltas are represented as integer
    # counts -- confirm against the dtype definitions.
    if numeric_times_timedelta or (
        left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype)
    ):
        return sge.Cast(this=sge.Floor(this=product), to="INT64")
    return product
119+
120+
76121
@BINARY_OP_REGISTRATION.register(ops.sub_op)
77122
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
78123
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
@@ -113,13 +158,3 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
113158
raise TypeError(
114159
f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
115160
)
116-
117-
118-
@BINARY_OP_REGISTRATION.register(ops.ge_op)
119-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
120-
return sge.GTE(this=left.expr, expression=right.expr)
121-
122-
123-
@BINARY_OP_REGISTRATION.register(ops.JSONSet)
124-
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
125-
return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr)

tests/system/small/engines/test_numeric_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def test_engines_project_sub(
7171
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
7272

7373

74-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
74+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
7575
def test_engines_project_mul(
7676
scalars_array_value: array_value.ArrayValue,
7777
engine,
@@ -80,7 +80,7 @@ def test_engines_project_mul(
8080
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
8181

8282

83-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
83+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
8484
def test_engines_project_div(scalars_array_value: array_value.ArrayValue, engine):
8585
# TODO: Duration div is sensitive to zeroes
8686
# TODO: Numeric col is sensitive to scale shifts
@@ -90,7 +90,7 @@ def test_engines_project_div(scalars_array_value: array_value.ArrayValue, engine
9090
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
9191

9292

93-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
93+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
9494
def test_engines_project_div_durations(
9595
scalars_array_value: array_value.ArrayValue, engine
9696
):
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
IEEE_DIVIDE(`bfcol_1`, `bfcol_1`) AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
IEEE_DIVIDE(`bfcol_7`, 1) AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
IEEE_DIVIDE(`bfcol_15`, CAST(`bfcol_16` AS INT64)) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
IEEE_DIVIDE(CAST(`bfcol_26` AS INT64), `bfcol_25`) AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_div_int`,
51+
`bfcol_40` AS `int_div_1`,
52+
`bfcol_41` AS `int_div_bool`,
53+
`bfcol_42` AS `bool_div_int`
54+
FROM `bfcte_4`
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`,
5+
`timestamp_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_1` AS `bfcol_6`,
11+
`bfcol_2` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
CAST(FLOOR(IEEE_DIVIDE(86400000000, `bfcol_0`)) AS INT64) AS `bfcol_9`
14+
FROM `bfcte_0`
15+
)
16+
SELECT
17+
`bfcol_6` AS `rowindex`,
18+
`bfcol_7` AS `timestamp_col`,
19+
`bfcol_8` AS `int64_col`,
20+
`bfcol_9` AS `timedelta_div_numeric`
21+
FROM `bfcte_1`
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_1` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
`bfcol_1` * `bfcol_1` AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
`bfcol_7` * 1 AS `bfcol_18`
23+
FROM `bfcte_1`
24+
), `bfcte_3` AS (
25+
SELECT
26+
*,
27+
`bfcol_14` AS `bfcol_24`,
28+
`bfcol_15` AS `bfcol_25`,
29+
`bfcol_16` AS `bfcol_26`,
30+
`bfcol_17` AS `bfcol_27`,
31+
`bfcol_18` AS `bfcol_28`,
32+
`bfcol_15` * CAST(`bfcol_16` AS INT64) AS `bfcol_29`
33+
FROM `bfcte_2`
34+
), `bfcte_4` AS (
35+
SELECT
36+
*,
37+
`bfcol_24` AS `bfcol_36`,
38+
`bfcol_25` AS `bfcol_37`,
39+
`bfcol_26` AS `bfcol_38`,
40+
`bfcol_27` AS `bfcol_39`,
41+
`bfcol_28` AS `bfcol_40`,
42+
`bfcol_29` AS `bfcol_41`,
43+
CAST(`bfcol_26` AS INT64) * `bfcol_25` AS `bfcol_42`
44+
FROM `bfcte_3`
45+
)
46+
SELECT
47+
`bfcol_36` AS `rowindex`,
48+
`bfcol_37` AS `int64_col`,
49+
`bfcol_38` AS `bool_col`,
50+
`bfcol_39` AS `int_mul_int`,
51+
`bfcol_40` AS `int_mul_1`,
52+
`bfcol_41` AS `int_mul_bool`,
53+
`bfcol_42` AS `bool_mul_int`
54+
FROM `bfcte_4`
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`rowindex` AS `bfcol_1`,
5+
`timestamp_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_1` AS `bfcol_6`,
11+
`bfcol_2` AS `bfcol_7`,
12+
`bfcol_0` AS `bfcol_8`,
13+
CAST(FLOOR(86400000000 * `bfcol_0`) AS INT64) AS `bfcol_9`
14+
FROM `bfcte_0`
15+
), `bfcte_2` AS (
16+
SELECT
17+
*,
18+
`bfcol_6` AS `bfcol_14`,
19+
`bfcol_7` AS `bfcol_15`,
20+
`bfcol_8` AS `bfcol_16`,
21+
`bfcol_9` AS `bfcol_17`,
22+
CAST(FLOOR(`bfcol_8` * 86400000000) AS INT64) AS `bfcol_18`
23+
FROM `bfcte_1`
24+
)
25+
SELECT
26+
`bfcol_14` AS `rowindex`,
27+
`bfcol_15` AS `timestamp_col`,
28+
`bfcol_16` AS `int64_col`,
29+
`bfcol_17` AS `timedelta_mul_numeric`,
30+
`bfcol_18` AS `numeric_mul_timedelta`
31+
FROM `bfcte_2`

tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,26 @@ def test_add_unsupported_raises(scalar_types_df: bpd.DataFrame):
8282
_apply_binary_op(scalar_types_df, ops.add_op, "int64_col", "string_col")
8383

8484

85+
def test_div_numeric(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for division between int and bool columns.

    Covers int / int, int / literal, int / bool and bool / int. The
    assignments are kept in this exact order because the generated SQL is
    compared against a stored snapshot.
    """
    frame = scalar_types_df[["int64_col", "bool_col"]]

    # Integer-only division: column by column, then column by a literal.
    frame["int_div_int"] = frame["int64_col"] / frame["int64_col"]
    frame["int_div_1"] = frame["int64_col"] / 1

    # Mixed int/bool division, with the bool on either side.
    frame["int_div_bool"] = frame["int64_col"] / frame["bool_col"]
    frame["bool_div_int"] = frame["bool_col"] / frame["int64_col"]

    snapshot.assert_match(frame.sql, "out.sql")
95+
96+
97+
def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for a timedelta divided by an int column."""
    frame = scalar_types_df[["timestamp_col", "int64_col"]]
    one_day = pd.Timedelta(1, unit="d")
    frame["timedelta_div_numeric"] = one_day / frame["int64_col"]

    snapshot.assert_match(frame.sql, "out.sql")
103+
104+
85105
def test_json_set(json_types_df: bpd.DataFrame, snapshot):
86106
bf_df = json_types_df[["json_col"]]
87107
sql = _apply_binary_op(
@@ -122,3 +142,25 @@ def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame):
122142

123143
with pytest.raises(TypeError):
124144
_apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col")
145+
146+
147+
def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for multiplying int and bool columns.

    Covers int * int, int * literal, int * bool and bool * int. The
    assignments are kept in this exact order because the generated SQL is
    compared against a stored snapshot.
    """
    frame = scalar_types_df[["int64_col", "bool_col"]]

    # Integer-only multiplication: column by column, then column by a literal.
    frame["int_mul_int"] = frame["int64_col"] * frame["int64_col"]
    frame["int_mul_1"] = frame["int64_col"] * 1

    # Mixed int/bool multiplication, with the bool on either side.
    frame["int_mul_bool"] = frame["int64_col"] * frame["bool_col"]
    frame["bool_mul_int"] = frame["bool_col"] * frame["int64_col"]

    snapshot.assert_match(frame.sql, "out.sql")
157+
158+
159+
def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for timedelta * int in both operand orders."""
    frame = scalar_types_df[["timestamp_col", "int64_col"]]
    one_day = pd.Timedelta(1, unit="d")

    # Timedelta on the left first, then on the right; the order is part of
    # the stored snapshot.
    frame["timedelta_mul_numeric"] = one_day * frame["int64_col"]
    frame["numeric_mul_timedelta"] = frame["int64_col"] * one_day

    snapshot.assert_match(frame.sql, "out.sql")

0 commit comments

Comments
 (0)