Skip to content

Commit 210cb2f

Browse files
committed
Merge branch 'main' into shuowei-anywidget-fix-empty-index
2 parents f4357b5 + dbe8e7e commit 210cb2f

File tree

12 files changed

+149
-15
lines changed

12 files changed

+149
-15
lines changed

bigframes/core/compile/sqlglot/aggregations/windows.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def apply_window_if_present(
4444
order_by = None
4545
elif window.is_range_bounded:
4646
order_by = get_window_order_by((window.ordering[0],))
47+
order_by = remove_null_ordering_for_range_windows(order_by)
4748
else:
4849
order_by = get_window_order_by(window.ordering)
4950

@@ -150,6 +151,30 @@ def get_window_order_by(
150151
return tuple(order_by)
151152

152153

154+
def remove_null_ordering_for_range_windows(
    order_by: typing.Optional[tuple[sge.Ordered, ...]],
) -> typing.Optional[tuple[sge.Ordered, ...]]:
    """Normalizes NULLS FIRST/LAST on ORDER BY expressions in RANGE windows.

    Each key is rewritten to the null placement that is supported for its
    sort direction. Here's the support matrix:
    ✅ sum(x) over (order by y desc nulls last)
    🚫 sum(x) over (order by y asc nulls last)
    ✅ sum(x) over (order by y asc nulls first)
    🚫 sum(x) over (order by y desc nulls first)

    Descending keys are forced to ``nulls_first=False`` and ascending keys to
    ``nulls_first=True``. Keys whose ``desc`` arg is neither ``True`` nor
    ``False`` are passed through with their args unchanged.

    Args:
        order_by: The ORDER BY keys of the window, or None when the window
            has no ordering.

    Returns:
        None if ``order_by`` is None; otherwise a tuple of new ``Ordered``
        expressions. The input expressions are never modified.
    """
    if order_by is None:
        return None

    normalized = []
    for key in order_by:
        # Work on a copy: mutating ``key.args`` in place would silently change
        # the caller's expressions, which may be reused for other windows.
        new_key = key.copy()
        desc = new_key.args.get("desc")
        nulls_first = new_key.args.get("nulls_first")
        if desc is True and nulls_first:
            new_key.set("nulls_first", False)
        elif desc is False and not nulls_first:
            # Also covers a missing ``nulls_first`` arg, which is made explicit.
            new_key.set("nulls_first", True)
        normalized.append(new_key)
    return tuple(normalized)
176+
177+
153178
def _get_window_bounds(
154179
value, is_preceding: bool
155180
) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]:

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,9 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
356356
observation_count = windows.apply_window_if_present(
357357
sge.func("SUM", is_observation), window_spec
358358
)
359+
observation_count = sge.func(
360+
"COALESCE", observation_count, sge.convert(0)
361+
)
359362
else:
360363
# Operations like count treat even NULLs as valid observations
361364
# for the sake of min_periods notnull is just used to convert

bigframes/core/compile/sqlglot/expressions/comparison_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,27 +89,39 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
8989

9090
@register_binary_op(ops.ge_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``left >= right``.

    If either operand is a NULL literal the comparison is replaced by a NULL
    literal; otherwise both operands go through ``_coerce_bool_to_int``
    before being compared.
    """
    if left.expr == sge.null():
        return sge.null()
    if right.expr == sge.null():
        return sge.null()

    return sge.GTE(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
9598

9699

97100
@register_binary_op(ops.gt_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``left > right``.

    If either operand is a NULL literal the comparison is replaced by a NULL
    literal; otherwise both operands go through ``_coerce_bool_to_int``
    before being compared.
    """
    if left.expr == sge.null():
        return sge.null()
    if right.expr == sge.null():
        return sge.null()

    return sge.GT(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
102108

103109

104110
@register_binary_op(ops.lt_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``left < right``.

    If either operand is a NULL literal the comparison is replaced by a NULL
    literal; otherwise both operands go through ``_coerce_bool_to_int``
    before being compared.
    """
    if left.expr == sge.null():
        return sge.null()
    if right.expr == sge.null():
        return sge.null()

    return sge.LT(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )
109118

110119

111120
@register_binary_op(ops.le_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``left <= right``.

    If either operand is a NULL literal the comparison is replaced by a NULL
    literal; otherwise both operands go through ``_coerce_bool_to_int``
    before being compared.
    """
    if left.expr == sge.null():
        return sge.null()
    if right.expr == sge.null():
        return sge.null()

    return sge.LTE(
        this=_coerce_bool_to_int(left),
        expression=_coerce_bool_to_int(right),
    )

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,19 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
140140
return sge.Coalesce(this=left.expr, expressions=[right.expr])
141141

142142

143+
@register_binary_op(ops.BinaryRemoteFunctionOp, pass_op=True)
def _(
    left: TypedExpr, right: TypedExpr, op: ops.BinaryRemoteFunctionOp
) -> sge.Expression:
    """Compile a call to the routine referenced by a binary remote function op.

    The routine is addressed by its project, dataset and routine IDs, and is
    invoked with the left and right operand expressions, in that order.
    """
    ref = op.function_def.routine_ref
    # Backtick-quote every path component so none of them can clash with a
    # SQL keyword.
    quoted_parts = [
        f"`{part}`" for part in (ref.project, ref.dataset_id, ref.routine_id)
    ]
    return sge.func(".".join(quoted_parts), left.expr, right.expr)
154+
155+
143156
@register_nary_op(ops.case_when_op)
144157
def _(*cases_and_outputs: TypedExpr) -> sge.Expression:
145158
# Need to upcast BOOL to INT if any output is numeric

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,9 @@ def _(expr: TypedExpr) -> sge.Expression:
388388

389389
@register_binary_op(ops.add_op)
390390
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
391+
if left.expr == sge.null() or right.expr == sge.null():
392+
return sge.null()
393+
391394
if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
392395
# String addition
393396
return sge.Concat(expressions=[left.expr, right.expr])
@@ -442,6 +445,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
442445

443446
@register_binary_op(ops.floordiv_op)
444447
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
448+
if left.expr == sge.null() or right.expr == sge.null():
449+
return sge.null()
450+
445451
left_expr = _coerce_bool_to_int(left)
446452
right_expr = _coerce_bool_to_int(right)
447453

@@ -525,6 +531,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
525531

526532
@register_binary_op(ops.mul_op)
527533
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
534+
if left.expr == sge.null() or right.expr == sge.null():
535+
return sge.null()
536+
528537
left_expr = _coerce_bool_to_int(left)
529538
right_expr = _coerce_bool_to_int(right)
530539

@@ -548,6 +557,9 @@ def _(expr: TypedExpr, n_digits: TypedExpr) -> sge.Expression:
548557

549558
@register_binary_op(ops.sub_op)
550559
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
560+
if left.expr == sge.null() or right.expr == sge.null():
561+
return sge.null()
562+
551563
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
552564
left_expr = _coerce_bool_to_int(left)
553565
right_expr = _coerce_bool_to_int(right)

docs/conf.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,12 @@
267267

268268
# https://sphinx-sitemap.readthedocs.io/en/latest/getting-started.html#usage
269269
html_baseurl = "https://dataframes.bigquery.dev/"
270+
sitemap_locales = [None]
271+
272+
# We don't have any immediate plans to translate the API reference, so omit the
273+
# language from the URLs.
274+
# https://sphinx-sitemap.readthedocs.io/en/latest/advanced-configuration.html#configuration-customizing-url-scheme
275+
sitemap_url_scheme = "{link}"
270276

271277
# -- Options for warnings ------------------------------------------------------
272278

tests/unit/core/compile/sqlglot/aggregations/test_windows.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def test_apply_window_if_present_range_bounded(self):
127127
)
128128
self.assertEqual(
129129
result.sql(dialect="bigquery"),
130-
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
130+
"value OVER (ORDER BY `col1` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
131131
)
132132

133133
def test_apply_window_if_present_range_bounded_timedelta(self):
@@ -142,7 +142,7 @@ def test_apply_window_if_present_range_bounded_timedelta(self):
142142
)
143143
self.assertEqual(
144144
result.sql(dialect="bigquery"),
145-
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
145+
"value OVER (ORDER BY `col1` ASC RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
146146
)
147147

148148
def test_apply_window_if_present_all_params(self):
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`float64_col`,
4+
`int64_col`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
`my_project`.`my_dataset`.`my_routine`(`int64_col`, `float64_col`) AS `bfcol_2`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_2` AS `int64_col`
14+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,43 @@ def test_astype_json_invalid(
168168
)
169169

170170

171+
def test_binary_remote_function_op(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for a two-argument BigQuery UDF call."""
    from google.cloud import bigquery

    from bigframes.functions import udf_def

    bf_df = scalar_types_df[["int64_col", "float64_col"]]

    # Build the UDF definition piece by piece: routine reference, then the
    # (INT64, FLOAT64) -> FLOAT64 signature.
    routine_ref = bigquery.RoutineReference.from_string(
        "my_project.my_dataset.my_routine"
    )
    x_field = udf_def.UdfField(
        "x",
        bigquery.StandardSqlDataType(type_kind=bigquery.StandardSqlTypeNames.INT64),
    )
    y_field = udf_def.UdfField(
        "y",
        bigquery.StandardSqlDataType(
            type_kind=bigquery.StandardSqlTypeNames.FLOAT64
        ),
    )
    output_type = bigquery.StandardSqlDataType(
        type_kind=bigquery.StandardSqlTypeNames.FLOAT64
    )
    signature = udf_def.UdfSignature(
        input_types=(x_field, y_field),
        output_bq_type=output_type,
    )
    function_def = udf_def.BigqueryUdf(routine_ref=routine_ref, signature=signature)
    op = ops.BinaryRemoteFunctionOp(function_def=function_def)

    sql = utils._apply_binary_op(bf_df, op, "int64_col", "float64_col")

    snapshot.assert_match(sql, "out.sql")
206+
207+
171208
def test_case_when_op(scalar_types_df: bpd.DataFrame, snapshot):
172209
ops_map = {
173210
"single_case": ops.case_when_op.as_expr(

tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,13 @@ WITH `bfcte_0` AS (
2222
SELECT
2323
*,
2424
CASE
25-
WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
26-
PARTITION BY `bfcol_9`
27-
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
28-
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
25+
WHEN COALESCE(
26+
SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
27+
PARTITION BY `bfcol_9`
28+
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
29+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
30+
),
31+
0
2932
) < 3
3033
THEN NULL
3134
ELSE COALESCE(
@@ -42,10 +45,13 @@ WITH `bfcte_0` AS (
4245
SELECT
4346
*,
4447
CASE
45-
WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
46-
PARTITION BY `bfcol_9`
47-
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
48-
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
48+
WHEN COALESCE(
49+
SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
50+
PARTITION BY `bfcol_9`
51+
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
52+
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
53+
),
54+
0
4955
) < 3
5056
THEN NULL
5157
ELSE COALESCE(

0 commit comments

Comments
 (0)