Skip to content

Commit 8f9cbc3

Browse files
authored
refactor: add agg_ops.TimeSeriesDiffOp and DateSeriesDiffOp to sqlglot compiler (#2164)
1 parent e0aa9cc commit 8f9cbc3

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,27 @@ def _(
9898
return apply_window_if_present(sge.func("COUNT", column.expr), window)
9999

100100

101+
@UNARY_OP_REGISTRATION.register(agg_ops.DateSeriesDiffOp)
102+
def _(
103+
op: agg_ops.DateSeriesDiffOp,
104+
column: typed_expr.TypedExpr,
105+
window: typing.Optional[window_spec.WindowSpec] = None,
106+
) -> sge.Expression:
107+
if column.dtype != dtypes.DATE_DTYPE:
108+
raise TypeError(f"Cannot perform date series diff on type {column.dtype}")
109+
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
110+
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
111+
# Conversion factor from days to microseconds
112+
conversion_factor = 24 * 60 * 60 * 1_000_000
113+
return sge.Cast(
114+
this=sge.DateDiff(
115+
this=column.expr, expression=shifted, unit=sge.Identifier(this="DAY")
116+
)
117+
* sge.convert(conversion_factor),
118+
to="INT64",
119+
)
120+
121+
101122
@UNARY_OP_REGISTRATION.register(agg_ops.DenseRankOp)
102123
def _(
103124
op: agg_ops.DenseRankOp,
@@ -293,3 +314,20 @@ def _(
293314
# Will be null if all inputs are null. Pandas defaults to zero sum though.
294315
zero = pd.to_timedelta(0) if column.dtype == dtypes.TIMEDELTA_DTYPE else 0
295316
return sge.func("IFNULL", expr, ir._literal(zero, column.dtype))
317+
318+
319+
@UNARY_OP_REGISTRATION.register(agg_ops.TimeSeriesDiffOp)
320+
def _(
321+
op: agg_ops.TimeSeriesDiffOp,
322+
column: typed_expr.TypedExpr,
323+
window: typing.Optional[window_spec.WindowSpec] = None,
324+
) -> sge.Expression:
325+
if column.dtype != dtypes.TIMESTAMP_DTYPE:
326+
raise TypeError(f"Cannot perform time series diff on type {column.dtype}")
327+
shift_op_impl = UNARY_OP_REGISTRATION[agg_ops.ShiftOp(0)]
328+
shifted = shift_op_impl(agg_ops.ShiftOp(op.periods), column, window)
329+
return sge.TimestampDiff(
330+
this=column.expr,
331+
expression=shifted,
332+
unit=sge.Identifier(this="MICROSECOND"),
333+
)
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`date_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CAST(DATE_DIFF(
9+
`bfcol_0`,
10+
LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST),
11+
DAY
12+
) * 86400000000 AS INT64) AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `diff_date`
17+
FROM `bfcte_1`
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`timestamp_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
TIMESTAMP_DIFF(
9+
`bfcol_0`,
10+
LAG(`bfcol_0`, 1) OVER (ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST),
11+
MICROSECOND
12+
) AS `bfcol_1`
13+
FROM `bfcte_0`
14+
)
15+
SELECT
16+
`bfcol_1` AS `diff_time`
17+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,17 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
127127
snapshot.assert_match(sql, "out.sql")
128128

129129

130+
def test_date_series_diff(scalar_types_df: bpd.DataFrame, snapshot):
131+
col_name = "date_col"
132+
bf_df = scalar_types_df[[col_name]]
133+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
134+
op = agg_exprs.UnaryAggregation(
135+
agg_ops.DateSeriesDiffOp(periods=1), expression.deref(col_name)
136+
)
137+
sql = _apply_unary_window_op(bf_df, op, window, "diff_date")
138+
snapshot.assert_match(sql, "out.sql")
139+
140+
130141
def test_diff(scalar_types_df: bpd.DataFrame, snapshot):
131142
# Test integer
132143
int_col = "int64_col"
@@ -331,3 +342,14 @@ def test_sum(scalar_types_df: bpd.DataFrame, snapshot):
331342
)
332343

333344
snapshot.assert_match(sql, "out.sql")
345+
346+
347+
def test_time_series_diff(scalar_types_df: bpd.DataFrame, snapshot):
348+
col_name = "timestamp_col"
349+
bf_df = scalar_types_df[[col_name]]
350+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
351+
op = agg_exprs.UnaryAggregation(
352+
agg_ops.TimeSeriesDiffOp(periods=1), expression.deref(col_name)
353+
)
354+
sql = _apply_unary_window_op(bf_df, op, window, "diff_time")
355+
snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)