Skip to content

Commit 8714977

Browse files
authored
refactor: support agg_ops.ShiftOp and DiffOp for the sqlglot compiler (#2156)
1 parent 7cb9e47 commit 8714977

File tree

7 files changed

+152
-0
lines changed

7 files changed

+152
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,23 @@ def _(
151151
)
152152

153153

154+
@UNARY_OP_REGISTRATION.register(agg_ops.DiffOp)
def _(
    op: agg_ops.DiffOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``DiffOp`` as the current value combined with its shifted value.

    The shifted operand is produced by the compiler registered for
    ``ShiftOp``, called with ``op.periods`` and the same ``window``.

    Args:
        op: The diff operation; only ``op.periods`` is read.
        column: The typed input expression.
        window: Optional window spec, forwarded unchanged to the shift compiler.

    Returns:
        ``column <> shifted`` for boolean columns, ``column - shifted`` for
        integer and float columns.

    Raises:
        TypeError: If ``column.dtype`` is not boolean, integer, or float.
    """
    # Delegate to the registered ShiftOp compiler. The ShiftOp(0) instance is
    # only a lookup key here -- presumably the registry dispatches on the op's
    # type; confirm against the registry implementation.
    compile_shift = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
    previous = compile_shift(agg_ops.ShiftOp(op.periods), column, window)

    # Booleans cannot be subtracted, so "diff" is expressed as inequality.
    if column.dtype == dtypes.BOOL_DTYPE:
        return sge.NEQ(this=column.expr, expression=previous)

    if column.dtype in (dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE):
        return sge.Sub(this=column.expr, expression=previous)

    raise TypeError(f"Cannot perform diff on type {column.dtype}")
169+
170+
154171
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
155172
def _(
156173
op: agg_ops.MaxOp,
@@ -240,6 +257,27 @@ def _(
240257
return apply_window_if_present(sge.func("COUNT", sge.convert(1)), window)
241258

242259

260+
@UNARY_OP_REGISTRATION.register(agg_ops.ShiftOp)
def _(
    op: agg_ops.ShiftOp,
    column: typed_expr.TypedExpr,
    window: typing.Optional[window_spec.WindowSpec] = None,
) -> sge.Expression:
    """Compile ``ShiftOp`` into a ``LAG`` or ``LEAD`` call.

    Args:
        op: The shift operation; only ``op.periods`` is read.
        column: The typed input expression to shift.
        window: Optional window spec passed to ``apply_window_if_present``.

    Returns:
        ``column.expr`` unchanged when ``op.periods`` is zero; otherwise
        ``LAG(column, periods)`` for positive periods or
        ``LEAD(column, -periods)`` for negative periods, passed through
        ``apply_window_if_present``.
    """
    periods = op.periods

    # A zero shift is the identity: return the bare column, no window applied.
    if periods == 0:
        return column.expr

    # Positive periods look backwards (LAG); negative ones look forwards
    # (LEAD) with the offset made positive.
    if periods > 0:
        func_name, offset = "LAG", periods
    else:
        func_name, offset = "LEAD", -periods

    # NOTE(review): framing clauses are suppressed here, presumably because
    # LAG/LEAD do not accept a window frame -- confirm against the SQL dialect.
    return apply_window_if_present(
        sge.func(func_name, column.expr, sge.convert(offset)),
        window,
        include_framing_clauses=False,
    )
279+
280+
243281
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
244282
def _(
245283
op: agg_ops.SumOp,
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0` <> LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `diff_bool`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0` - LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `diff_int`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `lag`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LEAD(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST) AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `lead`
13+
FROM `bfcte_1`
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0` AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `noop`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,28 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
127127
snapshot.assert_match(sql, "out.sql")
128128

129129

130+
def test_diff(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``DiffOp(periods=1)`` on int and bool columns."""
    # (source column, output column name, snapshot file), in execution order.
    cases = (
        ("int64_col", "diff_int", "diff_int.sql"),
        ("bool_col", "diff_bool", "diff_bool.sql"),
    )

    for col_name, out_name, snapshot_file in cases:
        bf_df = scalar_types_df[[col_name]]
        # Each case is windowed by ascending order over its own column.
        window = window_spec.WindowSpec(
            ordering=(ordering.ascending_over(col_name),)
        )
        diff_expr = agg_exprs.UnaryAggregation(
            agg_ops.DiffOp(periods=1), expression.deref(col_name)
        )
        sql = _apply_unary_window_op(bf_df, diff_expr, window, out_name)
        snapshot.assert_match(sql, snapshot_file)
150+
151+
130152
def test_first(scalar_types_df: bpd.DataFrame, snapshot):
131153
if sys.version_info < (3, 12):
132154
pytest.skip(
@@ -271,6 +293,33 @@ def test_rank(scalar_types_df: bpd.DataFrame, snapshot):
271293
snapshot.assert_match(sql, "out.sql")
272294

273295

296+
def test_shift(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ShiftOp`` with positive, negative and zero periods."""
    col_name = "int64_col"
    bf_df = scalar_types_df[[col_name]]
    window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))

    # (periods, output column name, snapshot file), in execution order:
    # a backward shift, a forward shift, then the zero-period no-op.
    cases = (
        (1, "lag", "lag.sql"),
        (-1, "lead", "lead.sql"),
        (0, "noop", "noop.sql"),
    )

    for periods, out_name, snapshot_file in cases:
        shift_expr = agg_exprs.UnaryAggregation(
            agg_ops.ShiftOp(periods=periods), expression.deref(col_name)
        )
        sql = _apply_unary_window_op(bf_df, shift_expr, window, out_name)
        snapshot.assert_match(sql, snapshot_file)
321+
322+
274323
def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
275324
bf_df = scalar_types_df[["int64_col", "bool_col"]]
276325
agg_ops_map = {

0 commit comments

Comments
 (0)