Skip to content

Commit 615a620

Browse files
authored
refactor: support agg_ops.LastOp, LastNonNullOp, FirstOp, FirstNonNullOp in the sqlglot compiler (#2153)
1 parent a410d0a commit 615a620

File tree

8 files changed

+186
-6
lines changed

8 files changed

+186
-6
lines changed

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,4 @@ def _(
5050
if window is None:
5151
# ROW_NUMBER always needs an OVER clause.
5252
return sge.Window(this=result)
53-
return apply_window_if_present(result, window)
53+
return apply_window_if_present(result, window, include_framing_clauses=False)

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,51 @@ def _(
104104
column: typed_expr.TypedExpr,
105105
window: typing.Optional[window_spec.WindowSpec] = None,
106106
) -> sge.Expression:
107-
return apply_window_if_present(sge.func("DENSE_RANK"), window)
107+
return apply_window_if_present(
108+
sge.func("DENSE_RANK"), window, include_framing_clauses=False
109+
)
110+
111+
112+
@UNARY_OP_REGISTRATION.register(agg_ops.FirstOp)
113+
def _(
114+
op: agg_ops.FirstOp,
115+
column: typed_expr.TypedExpr,
116+
window: typing.Optional[window_spec.WindowSpec] = None,
117+
) -> sge.Expression:
118+
# FIRST_VALUE in BQ respects nulls by default.
119+
return apply_window_if_present(sge.FirstValue(this=column.expr), window)
120+
121+
122+
@UNARY_OP_REGISTRATION.register(agg_ops.FirstNonNullOp)
123+
def _(
124+
op: agg_ops.FirstNonNullOp,
125+
column: typed_expr.TypedExpr,
126+
window: typing.Optional[window_spec.WindowSpec] = None,
127+
) -> sge.Expression:
128+
return apply_window_if_present(
129+
sge.IgnoreNulls(this=sge.FirstValue(this=column.expr)), window
130+
)
131+
132+
133+
@UNARY_OP_REGISTRATION.register(agg_ops.LastOp)
134+
def _(
135+
op: agg_ops.LastOp,
136+
column: typed_expr.TypedExpr,
137+
window: typing.Optional[window_spec.WindowSpec] = None,
138+
) -> sge.Expression:
139+
# LAST_VALUE in BQ respects nulls by default.
140+
return apply_window_if_present(sge.LastValue(this=column.expr), window)
141+
142+
143+
@UNARY_OP_REGISTRATION.register(agg_ops.LastNonNullOp)
144+
def _(
145+
op: agg_ops.LastNonNullOp,
146+
column: typed_expr.TypedExpr,
147+
window: typing.Optional[window_spec.WindowSpec] = None,
148+
) -> sge.Expression:
149+
return apply_window_if_present(
150+
sge.IgnoreNulls(this=sge.LastValue(this=column.expr)), window
151+
)
108152

109153

110154
@UNARY_OP_REGISTRATION.register(agg_ops.MaxOp)
@@ -182,7 +226,9 @@ def _(
182226
column: typed_expr.TypedExpr,
183227
window: typing.Optional[window_spec.WindowSpec] = None,
184228
) -> sge.Expression:
185-
return apply_window_if_present(sge.func("RANK"), window)
229+
return apply_window_if_present(
230+
sge.func("RANK"), window, include_framing_clauses=False
231+
)
186232

187233

188234
@UNARY_OP_REGISTRATION.register(agg_ops.SizeUnaryOp)

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
def apply_window_if_present(
2626
value: sge.Expression,
2727
window: typing.Optional[window_spec.WindowSpec] = None,
28+
include_framing_clauses: bool = True,
2829
) -> sge.Expression:
2930
if window is None:
3031
return value
@@ -64,11 +65,11 @@ def apply_window_if_present(
6465
if not window.bounds and not order:
6566
return sge.Window(this=value, partition_by=group_by)
6667

67-
if not window.bounds:
68+
if not window.bounds and not include_framing_clauses:
6869
return sge.Window(this=value, partition_by=group_by, order=order)
6970

7071
kind = (
71-
"ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
72+
"RANGE" if isinstance(window.bounds, window_spec.RangeWindowBounds) else "ROWS"
7273
)
7374

7475
start: typing.Union[int, float, None] = None
@@ -125,7 +126,7 @@ def get_window_order_by(
125126
nulls_first=nulls_first,
126127
)
127128
)
128-
elif not nulls_first and not desc:
129+
elif (not nulls_first) and (not desc):
129130
order_by.append(
130131
sge.Ordered(
131132
this=is_null_expr,
[File header not captured in this view: new `out.sql` snapshot for `test_first` (agg_ops.FirstOp)]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `bfcol_0` IS NULL
10+
THEN NULL
11+
ELSE FIRST_VALUE(`bfcol_0`) OVER (
12+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
13+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
14+
)
15+
END AS `bfcol_1`
16+
FROM `bfcte_0`
17+
)
18+
SELECT
19+
`bfcol_1` AS `agg_int64`
20+
FROM `bfcte_1`
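[File header not captured in this view: new `out.sql` snapshot for `test_first_non_null` (agg_ops.FirstNonNullOp)]
Lines changed: 16 additions & 0 deletions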
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
FIRST_VALUE(`bfcol_0` IGNORE NULLS) OVER (
9+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
10+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
11+
) AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `agg_int64`
16+
FROM `bfcte_1`
[File header not captured in this view: new `out.sql` snapshot for `test_last` (agg_ops.LastOp)]
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
CASE
9+
WHEN `bfcol_0` IS NULL
10+
THEN NULL
11+
ELSE LAST_VALUE(`bfcol_0`) OVER (
12+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
13+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
14+
)
15+
END AS `bfcol_1`
16+
FROM `bfcte_0`
17+
)
18+
SELECT
19+
`bfcol_1` AS `agg_int64`
20+
FROM `bfcte_1`
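[File header not captured in this view: new `out.sql` snapshot for `test_last_non_null` (agg_ops.LastNonNullOp)]
Lines changed: 16 additions & 0 deletions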
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
LAST_VALUE(`bfcol_0` IGNORE NULLS) OVER (
9+
ORDER BY `bfcol_0` IS NULL ASC NULLS LAST, `bfcol_0` ASC NULLS LAST
10+
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
11+
) AS `bfcol_1`
12+
FROM `bfcte_0`
13+
)
14+
SELECT
15+
`bfcol_1` AS `agg_int64`
16+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import sys
1516
import typing
1617

1718
import pytest
@@ -126,6 +127,66 @@ def test_dense_rank(scalar_types_df: bpd.DataFrame, snapshot):
126127
snapshot.assert_match(sql, "out.sql")
127128

128129

130+
def test_first(scalar_types_df: bpd.DataFrame, snapshot):
131+
if sys.version_info < (3, 12):
132+
pytest.skip(
133+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
134+
)
135+
col_name = "int64_col"
136+
bf_df = scalar_types_df[[col_name]]
137+
agg_expr = agg_exprs.UnaryAggregation(agg_ops.FirstOp(), expression.deref(col_name))
138+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
139+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
140+
141+
snapshot.assert_match(sql, "out.sql")
142+
143+
144+
def test_first_non_null(scalar_types_df: bpd.DataFrame, snapshot):
145+
if sys.version_info < (3, 12):
146+
pytest.skip(
147+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
148+
)
149+
col_name = "int64_col"
150+
bf_df = scalar_types_df[[col_name]]
151+
agg_expr = agg_exprs.UnaryAggregation(
152+
agg_ops.FirstNonNullOp(), expression.deref(col_name)
153+
)
154+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
155+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
156+
157+
snapshot.assert_match(sql, "out.sql")
158+
159+
160+
def test_last(scalar_types_df: bpd.DataFrame, snapshot):
161+
if sys.version_info < (3, 12):
162+
pytest.skip(
163+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
164+
)
165+
col_name = "int64_col"
166+
bf_df = scalar_types_df[[col_name]]
167+
agg_expr = agg_exprs.UnaryAggregation(agg_ops.LastOp(), expression.deref(col_name))
168+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
169+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
170+
171+
snapshot.assert_match(sql, "out.sql")
172+
173+
174+
def test_last_non_null(scalar_types_df: bpd.DataFrame, snapshot):
175+
if sys.version_info < (3, 12):
176+
pytest.skip(
177+
"Skipping test due to inconsistent SQL formatting on Python < 3.12.",
178+
)
179+
col_name = "int64_col"
180+
bf_df = scalar_types_df[[col_name]]
181+
agg_expr = agg_exprs.UnaryAggregation(
182+
agg_ops.LastNonNullOp(), expression.deref(col_name)
183+
)
184+
window = window_spec.WindowSpec(ordering=(ordering.ascending_over(col_name),))
185+
sql = _apply_unary_window_op(bf_df, agg_expr, window, "agg_int64")
186+
187+
snapshot.assert_match(sql, "out.sql")
188+
189+
129190
def test_max(scalar_types_df: bpd.DataFrame, snapshot):
130191
col_name = "int64_col"
131192
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)