Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions bigframes/core/compile/sqlglot/aggregations/windows.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def apply_window_if_present(
order_by = None
elif window.is_range_bounded:
order_by = get_window_order_by((window.ordering[0],))
order_by = remove_null_ordering_for_range_windows(order_by)
else:
order_by = get_window_order_by(window.ordering)

Expand Down Expand Up @@ -150,6 +151,30 @@ def get_window_order_by(
return tuple(order_by)


def remove_null_ordering_for_range_windows(
order_by: typing.Optional[tuple[sge.Ordered, ...]],
) -> typing.Optional[tuple[sge.Ordered, ...]]:
"""Removes NULL FIRST/LAST from ORDER BY expressions in RANGE windows.
Here's the support matrix:
✅ sum(x) over (order by y desc nulls last)
🚫 sum(x) over (order by y asc nulls last)
✅ sum(x) over (order by y asc nulls first)
🚫 sum(x) over (order by y desc nulls first)
"""
if order_by is None:
return None

new_order_by = []
for key in order_by:
kargs = key.args
if kargs.get("desc") is True and kargs.get("nulls_first", False):
kargs["nulls_first"] = False
elif kargs.get("desc") is False and not kargs.setdefault("nulls_first", True):
kargs["nulls_first"] = True
new_order_by.append(sge.Ordered(**kargs))
return tuple(new_order_by)


def _get_window_bounds(
value, is_preceding: bool
) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]:
Expand Down
3 changes: 3 additions & 0 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI
observation_count = windows.apply_window_if_present(
sge.func("SUM", is_observation), window_spec
)
observation_count = sge.func(
"COALESCE", observation_count, sge.convert(0)
)
else:
# Operations like count treat even NULLs as valid observations
# for the sake of min_periods notnull is just used to convert
Expand Down
12 changes: 12 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/comparison_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,27 +89,39 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:

@register_binary_op(ops.ge_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
return sge.GTE(this=left_expr, expression=right_expr)


@register_binary_op(ops.gt_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
return sge.GT(this=left_expr, expression=right_expr)


@register_binary_op(ops.lt_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
return sge.LT(this=left_expr, expression=right_expr)


@register_binary_op(ops.le_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
return sge.LTE(this=left_expr, expression=right_expr)
Expand Down
12 changes: 12 additions & 0 deletions bigframes/core/compile/sqlglot/expressions/numeric_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,9 @@ def _(expr: TypedExpr) -> sge.Expression:

@register_binary_op(ops.add_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE:
# String addition
return sge.Concat(expressions=[left.expr, right.expr])
Expand Down Expand Up @@ -442,6 +445,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:

@register_binary_op(ops.floordiv_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)

Expand Down Expand Up @@ -525,6 +531,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:

@register_binary_op(ops.mul_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)

Expand All @@ -548,6 +557,9 @@ def _(expr: TypedExpr, n_digits: TypedExpr) -> sge.Expression:

@register_binary_op(ops.sub_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
if left.expr == sge.null() or right.expr == sge.null():
return sge.null()

if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
left_expr = _coerce_bool_to_int(left)
right_expr = _coerce_bool_to_int(right)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_apply_window_if_present_range_bounded(self):
)
self.assertEqual(
result.sql(dialect="bigquery"),
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
"value OVER (ORDER BY `col1` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)",
)

def test_apply_window_if_present_range_bounded_timedelta(self):
Expand All @@ -142,7 +142,7 @@ def test_apply_window_if_present_range_bounded_timedelta(self):
)
self.assertEqual(
result.sql(dialect="bigquery"),
"value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
"value OVER (ORDER BY `col1` ASC RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)",
)

def test_apply_window_if_present_all_params(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,13 @@ WITH `bfcte_0` AS (
SELECT
*,
CASE
WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
WHEN COALESCE(
SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
),
0
) < 3
THEN NULL
ELSE COALESCE(
Expand All @@ -42,10 +45,13 @@ WITH `bfcte_0` AS (
SELECT
*,
CASE
WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
WHEN COALESCE(
SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER (
PARTITION BY `bfcol_9`
ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST
ROWS BETWEEN 3 PRECEDING AND CURRENT ROW
),
0
) < 3
THEN NULL
ELSE COALESCE(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@ WITH `bfcte_0` AS (
SELECT
*,
CASE
WHEN SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER (
ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST
RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW
WHEN COALESCE(
SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER (
ORDER BY UNIX_MICROS(`bfcol_0`) ASC
RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW
),
0
) < 1
THEN NULL
ELSE COALESCE(
SUM(`bfcol_1`) OVER (
ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST
ORDER BY UNIX_MICROS(`bfcol_0`) ASC
RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW
),
0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ WITH `bfcte_0` AS (
SELECT
*,
CASE
WHEN SUM(CAST(NOT `int64_col` IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) < 3
WHEN COALESCE(
SUM(CAST(NOT `int64_col` IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW),
0
) < 3
THEN NULL
ELSE COALESCE(
SUM(`int64_col`) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW),
Expand Down
Loading