Skip to content

Commit b023cb0

Browse files
authored
refactor: fix window and numeric/comparison ops for sqlglot compiler (#2372)
This change resolves the presubmit failures in #2248 for `test_series_int_int_operators_scalar` and for all tests in test_windows.py. Fixes internal issue 417774347 🦕
1 parent cc994f3 commit b023cb0

File tree

8 files changed

+79
-15
lines changed

8 files changed

+79
-15
lines changed

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def apply_window_if_present(
4444
order_by = None
4545
elif window.is_range_bounded:
4646
order_by = get_window_order_by((window.ordering[0],))
47+
order_by = remove_null_ordering_for_range_windows(order_by)
4748
else:
4849
order_by = get_window_order_by(window.ordering)
4950

@@ -150,6 +151,30 @@ def get_window_order_by(
150151
return tuple(order_by)
151152

152153

154+
def remove_null_ordering_for_range_windows(
    order_by: typing.Optional[tuple[sge.Ordered, ...]],
) -> typing.Optional[tuple[sge.Ordered, ...]]:
    """Removes unsupported NULLS FIRST/LAST placements from RANGE window keys.

    Each key's ``nulls_first`` flag is rewritten to the one placement marked
    as supported for its sort direction in the matrix below. This is
    presumably the engine's default placement for that direction, which lets
    the generated SQL omit an explicit NULLS FIRST/LAST clause.

    Here's the support matrix:
    ✅ sum(x) over (order by y desc nulls last)
    🚫 sum(x) over (order by y asc nulls last)
    ✅ sum(x) over (order by y asc nulls first)
    🚫 sum(x) over (order by y desc nulls first)

    Args:
        order_by: The ORDER BY keys of the window, or None when the window
            has no ordering.

    Returns:
        None when ``order_by`` is None. Otherwise a tuple of new
        ``sge.Ordered`` expressions, one per input key and in the same
        order. The input expressions are left untouched.
    """
    if order_by is None:
        return None

    new_order_by = []
    for key in order_by:
        # Work on a copy: editing ``key.args`` directly would change the
        # caller's expression, and passing the same child nodes to a second
        # ``sge.Ordered`` would leave them shared between two parents.
        new_key = key.copy()
        desc = new_key.args.get("desc")
        if desc is True:
            # DESC only supports NULLS LAST.
            if new_key.args.get("nulls_first", False):
                new_key.set("nulls_first", False)
        elif desc is False:
            # ASC only supports NULLS FIRST; this also fills in a missing flag.
            new_key.set("nulls_first", True)
        # Keys whose ``desc`` is neither True nor False are passed through
        # unchanged.
        new_order_by.append(new_key)
    return tuple(new_order_by)
176+
177+
153178
def _get_window_bounds(
154179
value, is_preceding: bool
155180
) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]:

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,9 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
356356
observation_count = windows.apply_window_if_present(
357357
sge.func("SUM", is_observation), window_spec
358358
)
359+
observation_count = sge.func(
360+
"COALESCE", observation_count, sge.convert(0)
361+
)
359362
else:
360363
# Operations like count treat even NULLs as valid observations
361364
# for the sake of min_periods notnull is just used to convert

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,27 +89,39 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
8989

9090
@register_binary_op(ops.ge_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compiles ``left >= right``.

    Returns a NULL literal when either operand's expression equals the NULL
    literal. Otherwise both operands go through ``_coerce_bool_to_int``
    before the comparison is built.
    """
    null_literal = sge.null()
    if any(operand.expr == null_literal for operand in (left, right)):
        return null_literal

    lhs = _coerce_bool_to_int(left)
    rhs = _coerce_bool_to_int(right)
    return sge.GTE(this=lhs, expression=rhs)
9598

9699

97100
@register_binary_op(ops.gt_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compiles ``left > right``.

    Returns a NULL literal when either operand's expression equals the NULL
    literal. Otherwise both operands go through ``_coerce_bool_to_int``
    before the comparison is built.
    """
    null_literal = sge.null()
    if any(operand.expr == null_literal for operand in (left, right)):
        return null_literal

    lhs = _coerce_bool_to_int(left)
    rhs = _coerce_bool_to_int(right)
    return sge.GT(this=lhs, expression=rhs)
102108

103109

104110
@register_binary_op(ops.lt_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compiles ``left < right``.

    Returns a NULL literal when either operand's expression equals the NULL
    literal. Otherwise both operands go through ``_coerce_bool_to_int``
    before the comparison is built.
    """
    null_literal = sge.null()
    if any(operand.expr == null_literal for operand in (left, right)):
        return null_literal

    lhs = _coerce_bool_to_int(left)
    rhs = _coerce_bool_to_int(right)
    return sge.LT(this=lhs, expression=rhs)
109118

110119

111120
@register_binary_op(ops.le_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compiles ``left <= right``.

    Returns a NULL literal when either operand's expression equals the NULL
    literal. Otherwise both operands go through ``_coerce_bool_to_int``
    before the comparison is built.
    """
    null_literal = sge.null()
    if any(operand.expr == null_literal for operand in (left, right)):
        return null_literal

    lhs = _coerce_bool_to_int(left)
    rhs = _coerce_bool_to_int(right)
    return sge.LTE(this=lhs, expression=rhs)

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,9 @@ def _(expr: TypedExpr) -> sge.Expression:
388388

389389
@register_binary_op(ops.add_op)
390390
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
391+
if left.expr == sge.null() or right.expr == sge.null():
392+
return sge.null()
393+
391394
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
392395
# String addition
393396
return sge.Concat(expressions=[left.expr, right.expr])
@@ -442,6 +445,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
442445

443446
@register_binary_op(ops.floordiv_op)
444447
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
448+
if left.expr == sge.null() or right.expr == sge.null():
449+
return sge.null()
450+
445451
left_expr = _coerce_bool_to_int(left)
446452
right_expr = _coerce_bool_to_int(right)
447453

@@ -525,6 +531,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
525531

526532
@register_binary_op(ops.mul_op)
527533
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
534+
if left.expr == sge.null() or right.expr == sge.null():
535+
return sge.null()
536+
528537
left_expr = _coerce_bool_to_int(left)
529538
right_expr = _coerce_bool_to_int(right)
530539

@@ -548,6 +557,9 @@ def _(expr: TypedExpr, n_digits: TypedExpr) -> sge.Expression:
548557

549558
@register_binary_op(ops.sub_op)
550559
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
560+
if left.expr == sge.null() or right.expr == sge.null():
561+
return sge.null()
562+
551563
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
552564
left_expr = _coerce_bool_to_int(left)
553565
right_expr = _coerce_bool_to_int(right)

tests/unit/core/compile/sqlglot/aggregations/test_windows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_apply_window_if_present_range_bounded(self):
127127
)
128128
self.assertEqual(
129129
result.sql(dialect="bigquery"),
130-
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
130+
"value OVER (ORDER BY `col1` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
131131
)
132132

133133
def test_apply_window_if_present_range_bounded_timedelta(self):
@@ -142,7 +142,7 @@ def test_apply_window_if_present_range_bounded_timedelta(self):
142142
)
143143
self.assertEqual(
144144
result.sql(dialect="bigquery"),
145-
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
145+
"value OVER (ORDER BY `col1` ASC RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
146146
)
147147

148148
def test_apply_window_if_present_all_params(self):

tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ WITH `bfcte_0` AS (
2222
SELECT
2323
*,
2424
CASE
25-
WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
26-
PARTITION BY `bfcol_9`
27-
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
28-
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
25+
WHEN COALESCE(
26+
SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
27+
PARTITION BY `bfcol_9`
28+
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
29+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
30+
),
31+
0
2932
) < 3
3033
THEN NULL
3134
ELSE COALESCE(
@@ -42,10 +45,13 @@ WITH `bfcte_0` AS (
4245
SELECT
4346
*,
4447
CASE
45-
WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
46-
PARTITION BY `bfcol_9`
47-
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
48-
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
48+
WHEN COALESCE(
49+
SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
50+
PARTITION BY `bfcol_9`
51+
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
52+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
53+
),
54+
0
4955
) < 3
5056
THEN NULL
5157
ELSE COALESCE(

tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,17 @@ WITH `bfcte_0` AS (
66
SELECT
77
*,
88
CASE
9-
WHEN SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER (
10-
ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST
11-
RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW
9+
WHEN COALESCE(
10+
SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER (
11+
ORDER BY UNIX_MICROS(`bfcol_0`) ASC
12+
RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW
13+
),
14+
0
1215
) < 1
1316
THEN NULL
1417
ELSE COALESCE(
1518
SUM(`bfcol_1`) OVER (
16-
ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST
19+
ORDER BY UNIX_MICROS(`bfcol_0`) ASC
1720
RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW
1821
),
1922
0

tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ WITH `bfcte_0` AS (
77
SELECT
88
*,
99
CASE
10-
WHEN SUM(CAST(NOT `int64_col` IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) < 3
10+
WHEN COALESCE(
11+
SUM(CAST(NOT `int64_col` IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW),
12+
0
13+
) < 3
1114
THEN NULL
1215
ELSE COALESCE(
1316
SUM(`int64_col`) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW),

0 commit comments

Comments
 (0)