Skip to content

Commit b6aeca3

Browse files
authored
refactor: add apply_window_if_present and get_window_order_by methods (#1947)
Fixes internal issue 430350912
1 parent fd72b4e commit b6aeca3

File tree

6 files changed

+305
-41
lines changed

6 files changed

+305
-41
lines changed

bigframes/core/compile/sqlglot/aggregations/nullary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23-
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2424
from bigframes.operations import aggregations as agg_ops
2525

2626
NULLARY_OP_REGISTRATION = reg.OpRegistration()

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
from bigframes.core import window_spec
2222
import bigframes.core.compile.sqlglot.aggregations.op_registration as reg
23-
from bigframes.core.compile.sqlglot.aggregations.utils import apply_window_if_present
23+
from bigframes.core.compile.sqlglot.aggregations.windows import apply_window_if_present
2424
import bigframes.core.compile.sqlglot.expressions.typed_expr as typed_expr
2525
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
2626
from bigframes.operations import aggregations as agg_ops

bigframes/core/compile/sqlglot/aggregations/utils.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from __future__ import annotations
15+
16+
import typing
17+
18+
import sqlglot.expressions as sge
19+
20+
from bigframes.core import utils, window_spec
21+
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
22+
import bigframes.core.ordering as ordering_spec
23+
24+
25+
def apply_window_if_present(
26+
value: sge.Expression,
27+
window: typing.Optional[window_spec.WindowSpec] = None,
28+
) -> sge.Expression:
29+
if window is None:
30+
return value
31+
32+
if window.is_row_bounded and not window.ordering:
33+
raise ValueError("No ordering provided for ordered analytic function")
34+
elif (
35+
not window.is_row_bounded
36+
and not window.is_range_bounded
37+
and not window.ordering
38+
):
39+
# Unbound grouping window.
40+
order_by = None
41+
elif window.is_range_bounded:
42+
# Note that, when the window is range-bounded, we only need one ordering key.
43+
# There are two reasons:
44+
# 1. Manipulating null positions requires more than one ordering key, which
45+
# is forbidden by SQL window syntax for range rolling.
46+
# 2. Pandas does not allow range rolling on timeseries with nulls.
47+
order_by = get_window_order_by((window.ordering[0],), override_null_order=False)
48+
else:
49+
order_by = get_window_order_by(window.ordering, override_null_order=True)
50+
51+
order = sge.Order(expressions=order_by) if order_by else None
52+
53+
group_by = (
54+
[scalar_compiler.compile_scalar_expression(key) for key in window.grouping_keys]
55+
if window.grouping_keys
56+
else None
57+
)
58+
59+
# This is the key change. Don't create a spec for the default window frame
60+
# if there's no ordering. This avoids generating an `ORDER BY NULL` clause.
61+
if not window.bounds and not order:
62+
return sge.Window(this=value, partition_by=group_by)
63+
64+
kind = (
65+
"ROWS" if isinstance(window.bounds, window_spec.RowsWindowBounds) else "RANGE"
66+
)
67+
68+
start: typing.Union[int, float, None] = None
69+
end: typing.Union[int, float, None] = None
70+
if isinstance(window.bounds, window_spec.RangeWindowBounds):
71+
if window.bounds.start is not None:
72+
start = utils.timedelta_to_micros(window.bounds.start)
73+
if window.bounds.end is not None:
74+
end = utils.timedelta_to_micros(window.bounds.end)
75+
elif window.bounds:
76+
start = window.bounds.start
77+
end = window.bounds.end
78+
79+
start_value, start_side = _get_window_bounds(start, is_preceding=True)
80+
end_value, end_side = _get_window_bounds(end, is_preceding=False)
81+
82+
spec = sge.WindowSpec(
83+
kind=kind,
84+
start=start_value,
85+
start_side=start_side,
86+
end=end_value,
87+
end_side=end_side,
88+
over="OVER",
89+
)
90+
91+
return sge.Window(this=value, partition_by=group_by, order=order, spec=spec)
92+
93+
94+
def get_window_order_by(
95+
ordering: typing.Tuple[ordering_spec.OrderingExpression, ...],
96+
override_null_order: bool = False,
97+
) -> typing.Optional[tuple[sge.Ordered, ...]]:
98+
"""Returns the SQL order by clause for a window specification."""
99+
if not ordering:
100+
return None
101+
102+
order_by = []
103+
for ordering_spec_item in ordering:
104+
expr = scalar_compiler.compile_scalar_expression(
105+
ordering_spec_item.scalar_expression
106+
)
107+
desc = not ordering_spec_item.direction.is_ascending
108+
nulls_first = not ordering_spec_item.na_last
109+
110+
if override_null_order:
111+
# Bigquery SQL considers NULLS to be "smallest" values, but we need
112+
# to override in these cases.
113+
is_null_expr = sge.Is(this=expr, expression=sge.Null())
114+
if nulls_first and desc:
115+
order_by.append(
116+
sge.Ordered(
117+
this=is_null_expr,
118+
desc=desc,
119+
nulls_first=nulls_first,
120+
)
121+
)
122+
elif not nulls_first and not desc:
123+
order_by.append(
124+
sge.Ordered(
125+
this=is_null_expr,
126+
desc=desc,
127+
nulls_first=nulls_first,
128+
)
129+
)
130+
131+
order_by.append(
132+
sge.Ordered(
133+
this=expr,
134+
desc=desc,
135+
nulls_first=nulls_first,
136+
)
137+
)
138+
return tuple(order_by)
139+
140+
141+
def _get_window_bounds(
142+
value, is_preceding: bool
143+
) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]:
144+
"""Compiles a single boundary value into its SQL components."""
145+
if value is None:
146+
side = "PRECEDING" if is_preceding else "FOLLOWING"
147+
return "UNBOUNDED", side
148+
149+
if value == 0:
150+
return "CURRENT ROW", None
151+
152+
side = "PRECEDING" if value < 0 else "FOLLOWING"
153+
return sge.convert(abs(value)), side

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bigframes.core import expression, guid, identifiers, nodes, pyarrow_utils, rewrite
2424
from bigframes.core.compile import configs
2525
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
26+
from bigframes.core.compile.sqlglot.aggregations import windows
2627
from bigframes.core.compile.sqlglot.expressions import typed_expr
2728
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
2829
import bigframes.core.compile.sqlglot.sqlglot_ir as ir
@@ -272,18 +273,16 @@ def compile_random_sample(
272273
def compile_aggregate(
273274
self, node: nodes.AggregateNode, child: ir.SQLGlotIR
274275
) -> ir.SQLGlotIR:
275-
ordering_cols = tuple(
276-
sge.Ordered(
277-
this=scalar_compiler.compile_scalar_expression(
278-
ordering.scalar_expression
279-
),
280-
desc=ordering.direction.is_ascending is False,
281-
nulls_first=ordering.na_last is False,
282-
)
283-
for ordering in node.order_by
276+
ordering_cols = windows.get_window_order_by(
277+
node.order_by, override_null_order=True
284278
)
285279
aggregations: tuple[tuple[str, sge.Expression], ...] = tuple(
286-
(id.sql, aggregate_compiler.compile_aggregate(agg, order_by=ordering_cols))
280+
(
281+
id.sql,
282+
aggregate_compiler.compile_aggregate(
283+
agg, order_by=ordering_cols if ordering_cols else ()
284+
),
285+
)
287286
for agg, id in node.aggregations
288287
)
289288
by_cols: tuple[sge.Expression, ...] = tuple(
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import pandas as pd
18+
import pytest
19+
import sqlglot.expressions as sge
20+
21+
from bigframes.core import window_spec
22+
from bigframes.core.compile.sqlglot.aggregations.windows import (
23+
apply_window_if_present,
24+
get_window_order_by,
25+
)
26+
import bigframes.core.expression as ex
27+
import bigframes.core.ordering as ordering
28+
29+
30+
class WindowsTest(unittest.TestCase):
31+
def test_get_window_order_by_empty(self):
32+
self.assertIsNone(get_window_order_by(tuple()))
33+
34+
def test_get_window_order_by(self):
35+
result = get_window_order_by((ordering.OrderingExpression(ex.deref("col1")),))
36+
self.assertEqual(
37+
sge.Order(expressions=result).sql(dialect="bigquery"),
38+
"ORDER BY `col1` ASC NULLS LAST",
39+
)
40+
41+
def test_get_window_order_by_override_nulls(self):
42+
result = get_window_order_by(
43+
(ordering.OrderingExpression(ex.deref("col1")),),
44+
override_null_order=True,
45+
)
46+
self.assertEqual(
47+
sge.Order(expressions=result).sql(dialect="bigquery"),
48+
"ORDER BY `col1` IS NULL ASC NULLS LAST, `col1` ASC NULLS LAST",
49+
)
50+
51+
def test_get_window_order_by_override_nulls_desc(self):
52+
result = get_window_order_by(
53+
(
54+
ordering.OrderingExpression(
55+
ex.deref("col1"),
56+
direction=ordering.OrderingDirection.DESC,
57+
na_last=False,
58+
),
59+
),
60+
override_null_order=True,
61+
)
62+
self.assertEqual(
63+
sge.Order(expressions=result).sql(dialect="bigquery"),
64+
"ORDER BY `col1` IS NULL DESC NULLS FIRST, `col1` DESC NULLS FIRST",
65+
)
66+
67+
def test_apply_window_if_present_no_window(self):
68+
value = sge.func(
69+
"SUM", sge.Column(this=sge.to_identifier("col_0", quoted=True))
70+
)
71+
result = apply_window_if_present(value)
72+
self.assertEqual(result, value)
73+
74+
def test_apply_window_if_present_row_bounded_no_ordering_raises(self):
75+
with pytest.raises(
76+
ValueError, match="No ordering provided for ordered analytic function"
77+
):
78+
apply_window_if_present(
79+
sge.Var(this="value"),
80+
window_spec.WindowSpec(
81+
bounds=window_spec.RowsWindowBounds(start=-1, end=1)
82+
),
83+
)
84+
85+
def test_apply_window_if_present_unbounded_grouping_no_ordering(self):
86+
result = apply_window_if_present(
87+
sge.Var(this="value"),
88+
window_spec.WindowSpec(
89+
grouping_keys=(ex.deref("col1"),),
90+
),
91+
)
92+
self.assertEqual(
93+
result.sql(dialect="bigquery"),
94+
"value OVER (PARTITION BY `col1`)",
95+
)
96+
97+
def test_apply_window_if_present_range_bounded(self):
98+
result = apply_window_if_present(
99+
sge.Var(this="value"),
100+
window_spec.WindowSpec(
101+
ordering=(ordering.OrderingExpression(ex.deref("col1")),),
102+
bounds=window_spec.RangeWindowBounds(start=None, end=pd.Timedelta(0)),
103+
),
104+
)
105+
self.assertEqual(
106+
result.sql(dialect="bigquery"),
107+
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
108+
)
109+
110+
def test_apply_window_if_present_range_bounded_timedelta(self):
111+
result = apply_window_if_present(
112+
sge.Var(this="value"),
113+
window_spec.WindowSpec(
114+
ordering=(ordering.OrderingExpression(ex.deref("col1")),),
115+
bounds=window_spec.RangeWindowBounds(
116+
start=pd.Timedelta(days=-1), end=pd.Timedelta(hours=12)
117+
),
118+
),
119+
)
120+
self.assertEqual(
121+
result.sql(dialect="bigquery"),
122+
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
123+
)
124+
125+
def test_apply_window_if_present_all_params(self):
126+
result = apply_window_if_present(
127+
sge.Var(this="value"),
128+
window_spec.WindowSpec(
129+
grouping_keys=(ex.deref("col1"),),
130+
ordering=(ordering.OrderingExpression(ex.deref("col2")),),
131+
bounds=window_spec.RowsWindowBounds(start=-1, end=0),
132+
),
133+
)
134+
self.assertEqual(
135+
result.sql(dialect="bigquery"),
136+
"value OVER (PARTITION BY `col1` ORDER BY `col2` IS NULL ASC NULLS LAST, `col2` ASC NULLS LAST ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)",
137+
)
138+
139+
140+
if __name__ == "__main__":
141+
unittest.main()

0 commit comments

Comments
 (0)