Skip to content

Commit e95dc2c

Browse files
authored
refactor: support ops.case_when_op and fix invert_op for the sqlglot compiler (#2174)
1 parent 0a44e84 commit e95dc2c

File tree

8 files changed (+152 / -43 lines changed)

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2424

2525
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
26+
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
2627
register_ternary_op = scalar_compiler.scalar_op_compiler.register_ternary_op
2728

2829

@@ -67,23 +68,18 @@ def _(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
6768
return _cast(sg_expr, sg_to_type, op.safe)
6869

6970

70-
@register_ternary_op(ops.clip_op)
71-
def _(
72-
original: TypedExpr,
73-
lower: TypedExpr,
74-
upper: TypedExpr,
75-
) -> sge.Expression:
76-
return sge.Greatest(
77-
this=sge.Least(this=original.expr, expressions=[upper.expr]),
78-
expressions=[lower.expr],
79-
)
80-
81-
8271
@register_unary_op(ops.hash_op)
8372
def _(expr: TypedExpr) -> sge.Expression:
8473
return sge.func("FARM_FINGERPRINT", expr.expr)
8574

8675

76+
@register_unary_op(ops.invert_op)
77+
def _(expr: TypedExpr) -> sge.Expression:
78+
if expr.dtype == dtypes.BOOL_DTYPE:
79+
return sge.Not(this=expr.expr)
80+
return sge.BitwiseNot(this=expr.expr)
81+
82+
8783
@register_unary_op(ops.isnull_op)
8884
def _(expr: TypedExpr) -> sge.Expression:
8985
return sge.Is(this=expr.expr, expression=sge.Null())
@@ -114,6 +110,44 @@ def _(
114110
return sge.If(this=condition.expr, true=original.expr, false=replacement.expr)
115111

116112

113+
@register_ternary_op(ops.clip_op)
114+
def _(
115+
original: TypedExpr,
116+
lower: TypedExpr,
117+
upper: TypedExpr,
118+
) -> sge.Expression:
119+
return sge.Greatest(
120+
this=sge.Least(this=original.expr, expressions=[upper.expr]),
121+
expressions=[lower.expr],
122+
)
123+
124+
125+
@register_nary_op(ops.case_when_op)
126+
def _(*cases_and_outputs: TypedExpr) -> sge.Expression:
127+
# Need to upcast BOOL to INT if any output is numeric
128+
result_values = cases_and_outputs[1::2]
129+
do_upcast_bool = any(
130+
dtypes.is_numeric(t.dtype, include_bool=False) for t in result_values
131+
)
132+
if do_upcast_bool:
133+
result_values = tuple(
134+
TypedExpr(
135+
sge.Cast(this=val.expr, to="INT64"),
136+
dtypes.INT_DTYPE,
137+
)
138+
if val.dtype == dtypes.BOOL_DTYPE
139+
else val
140+
for val in result_values
141+
)
142+
143+
return sge.Case(
144+
ifs=[
145+
sge.If(this=predicate.expr, true=output.expr)
146+
for predicate, output in zip(cases_and_outputs[::2], result_values)
147+
],
148+
)
149+
150+
117151
# Helper functions
118152
def _cast_to_json(expr: TypedExpr, op: ops.AsTypeOp) -> sge.Expression:
119153
from_type = expr.dtype

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,6 @@ def _(expr: TypedExpr) -> sge.Expression:
148148
return sge.Floor(this=expr.expr)
149149

150150

151-
@register_unary_op(ops.invert_op)
152-
def _(expr: TypedExpr) -> sge.Expression:
153-
return sge.BitwiseNot(this=expr.expr)
154-
155-
156151
@register_unary_op(ops.ln_op)
157152
def _(expr: TypedExpr) -> sge.Expression:
158153
return sge.Case(

tests/system/small/engines/test_generic_ops.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def test_engines_fillna_op(scalars_array_value: array_value.ArrayValue, engine):
357357
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
358358

359359

360-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
360+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
361361
def test_engines_casewhen_op_single_case(
362362
scalars_array_value: array_value.ArrayValue, engine
363363
):
@@ -373,7 +373,7 @@ def test_engines_casewhen_op_single_case(
373373
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
374374

375375

376-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
376+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
377377
def test_engines_casewhen_op_double_case(
378378
scalars_array_value: array_value.ArrayValue, engine
379379
):
@@ -391,7 +391,7 @@ def test_engines_casewhen_op_double_case(
391391
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
392392

393393

394-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
394+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
395395
def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine):
396396
arr, _ = scalars_array_value.compute_values(
397397
[ops.isnull_op.as_expr(expression.deref("string_col"))]
@@ -400,7 +400,7 @@ def test_engines_isnull_op(scalars_array_value: array_value.ArrayValue, engine):
400400
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
401401

402402

403-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
403+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
404404
def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine):
405405
arr, _ = scalars_array_value.compute_values(
406406
[ops.notnull_op.as_expr(expression.deref("string_col"))]
@@ -409,7 +409,7 @@ def test_engines_notnull_op(scalars_array_value: array_value.ArrayValue, engine)
409409
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
410410

411411

412-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
412+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
413413
def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine):
414414
arr, _ = scalars_array_value.compute_values(
415415
[
tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_case_when_op/out.sql

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`int64_col` AS `bfcol_1`,
5+
`int64_too` AS `bfcol_2`,
6+
`float64_col` AS `bfcol_3`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
8+
), `bfcte_1` AS (
9+
SELECT
10+
*,
11+
CASE WHEN `bfcol_0` THEN `bfcol_1` END AS `bfcol_4`,
12+
CASE WHEN `bfcol_0` THEN `bfcol_1` WHEN `bfcol_0` THEN `bfcol_2` END AS `bfcol_5`,
13+
CASE WHEN `bfcol_0` THEN `bfcol_0` WHEN `bfcol_0` THEN `bfcol_0` END AS `bfcol_6`,
14+
CASE
15+
WHEN `bfcol_0`
16+
THEN `bfcol_1`
17+
WHEN `bfcol_0`
18+
THEN CAST(`bfcol_0` AS INT64)
19+
WHEN `bfcol_0`
20+
THEN `bfcol_3`
21+
END AS `bfcol_7`
22+
FROM `bfcte_0`
23+
)
24+
SELECT
25+
`bfcol_4` AS `single_case`,
26+
`bfcol_5` AS `double_case`,
27+
`bfcol_6` AS `bool_types_case`,
28+
`bfcol_7` AS `mixed_types_cast`
29+
FROM `bfcte_1`
tests/unit/core/compile/sqlglot/expressions/snapshots/test_generic_ops/test_invert/out.sql

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`bytes_col` AS `bfcol_1`,
5+
`int64_col` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
~`bfcol_2` AS `bfcol_6`,
11+
~`bfcol_1` AS `bfcol_7`,
12+
NOT `bfcol_0` AS `bfcol_8`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_6` AS `int64_col`,
17+
`bfcol_7` AS `bytes_col`,
18+
`bfcol_8` AS `bool_col`
19+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_invert/out.sql

Lines changed: 0 additions & 13 deletions
This file was deleted.

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,47 @@ def test_astype_json_invalid(
168168
)
169169

170170

171+
def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot):
172+
ops_map = {
173+
"single_case": ops.case_when_op.as_expr(
174+
"bool_col",
175+
"int64_col",
176+
),
177+
"double_case": ops.case_when_op.as_expr(
178+
"bool_col",
179+
"int64_col",
180+
"bool_col",
181+
"int64_too",
182+
),
183+
"bool_types_case": ops.case_when_op.as_expr(
184+
"bool_col",
185+
"bool_col",
186+
"bool_col",
187+
"bool_col",
188+
),
189+
"mixed_types_cast": ops.case_when_op.as_expr(
190+
"bool_col",
191+
"int64_col",
192+
"bool_col",
193+
"bool_col",
194+
"bool_col",
195+
"float64_col",
196+
),
197+
}
198+
199+
array_value = scalar_types_df._block.expr
200+
result, col_ids = array_value.compute_values(list(ops_map.values()))
201+
202+
# Rename columns for deterministic golden SQL results.
203+
assert len(col_ids) == len(ops_map.keys())
204+
result = result.rename_columns(
205+
{col_id: key for col_id, key in zip(col_ids, ops_map.keys())}
206+
).select_columns(list(ops_map.keys()))
207+
208+
sql = result.session._executor.to_sql(result, enable_cache=False)
209+
snapshot.assert_match(sql, "out.sql")
210+
211+
171212
def test_clip(scalar_types_df: bpd.DataFrame, snapshot):
172213
op_expr = ops.clip_op.as_expr("rowindex", "int64_col", "int64_too")
173214

@@ -192,6 +233,18 @@ def test_hash(scalar_types_df: bpd.DataFrame, snapshot):
192233
snapshot.assert_match(sql, "out.sql")
193234

194235

236+
def test_invert(scalar_types_df: bpd.DataFrame, snapshot):
237+
bf_df = scalar_types_df[["int64_col", "bytes_col", "bool_col"]]
238+
ops_map = {
239+
"int64_col": ops.invert_op.as_expr("int64_col"),
240+
"bytes_col": ops.invert_op.as_expr("bytes_col"),
241+
"bool_col": ops.invert_op.as_expr("bool_col"),
242+
}
243+
sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys()))
244+
245+
snapshot.assert_match(sql, "out.sql")
246+
247+
195248
def test_isnull(scalar_types_df: bpd.DataFrame, snapshot):
196249
col_name = "float64_col"
197250
bf_df = scalar_types_df[[col_name]]

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -126,14 +126,6 @@ def test_floor(scalar_types_df: bpd.DataFrame, snapshot):
126126
snapshot.assert_match(sql, "out.sql")
127127

128128

129-
def test_invert(scalar_types_df: bpd.DataFrame, snapshot):
130-
col_name = "int64_col"
131-
bf_df = scalar_types_df[[col_name]]
132-
sql = utils._apply_unary_ops(bf_df, [ops.invert_op.as_expr(col_name)], [col_name])
133-
134-
snapshot.assert_match(sql, "out.sql")
135-
136-
137129
def test_ln(scalar_types_df: bpd.DataFrame, snapshot):
138130
col_name = "float64_col"
139131
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)