diff --git a/bigframes/core/compile/sqlglot/aggregations/windows.py b/bigframes/core/compile/sqlglot/aggregations/windows.py index 6d6c507455..9c32788585 100644 --- a/bigframes/core/compile/sqlglot/aggregations/windows.py +++ b/bigframes/core/compile/sqlglot/aggregations/windows.py @@ -44,6 +44,7 @@ def apply_window_if_present( order_by = None elif window.is_range_bounded: order_by = get_window_order_by((window.ordering[0],)) + order_by = remove_null_ordering_for_range_windows(order_by) else: order_by = get_window_order_by(window.ordering) @@ -150,6 +151,30 @@ def get_window_order_by( return tuple(order_by) +def remove_null_ordering_for_range_windows( + order_by: typing.Optional[tuple[sge.Ordered, ...]], +) -> typing.Optional[tuple[sge.Ordered, ...]]: + """Removes NULL FIRST/LAST from ORDER BY expressions in RANGE windows. + Here's the support matrix: + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if order_by is None: + return None + + new_order_by = [] + for key in order_by: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault("nulls_first", True): + kargs["nulls_first"] = True + new_order_by.append(sge.Ordered(**kargs)) + return tuple(new_order_by) + + def _get_window_bounds( value, is_preceding: bool ) -> tuple[typing.Union[str, sge.Expression], typing.Optional[str]]: diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index b3b813a1c0..e77370892c 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -356,6 +356,9 @@ def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotI observation_count = windows.apply_window_if_present( sge.func("SUM", is_observation), window_spec ) + observation_count = sge.func( + "COALESCE", observation_count, sge.convert(0) + ) else: # Operations like count treat even NULLs as valid observations # for the sake of min_periods notnull is just used to convert diff --git a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index 81bc9e0f56..8fda3b80dd 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -89,6 +89,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.ge_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.GTE(this=left_expr, expression=right_expr) @@ -96,6 +99,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.gt_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.GT(this=left_expr, expression=right_expr) @@ -103,6 +109,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.lt_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.LT(this=left_expr, expression=right_expr) @@ -110,6 +119,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.le_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) return sge.LTE(this=left_expr, expression=right_expr) diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index 16f7dec717..f7c763e207 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -388,6 +388,9 @@ def _(expr: TypedExpr) -> sge.Expression: @register_binary_op(ops.add_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE: # String addition return sge.Concat(expressions=[left.expr, right.expr]) @@ -442,6 +445,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.floordiv_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -525,6 +531,9 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: @register_binary_op(ops.mul_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) @@ -548,6 +557,9 @@ def _(expr: TypedExpr, n_digits: TypedExpr) -> sge.Expression: @register_binary_op(ops.sub_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.expr == sge.null() or right.expr == sge.null(): + return sge.null() + if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): left_expr = _coerce_bool_to_int(left) right_expr = _coerce_bool_to_int(right) diff --git a/tests/unit/core/compile/sqlglot/aggregations/test_windows.py b/tests/unit/core/compile/sqlglot/aggregations/test_windows.py index e6343a63d7..d1204c6010 100644 --- a/tests/unit/core/compile/sqlglot/aggregations/test_windows.py +++ b/tests/unit/core/compile/sqlglot/aggregations/test_windows.py @@ -127,7 +127,7 @@ def test_apply_window_if_present_range_bounded(self): ) self.assertEqual( result.sql(dialect="bigquery"), - "value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", + "value OVER (ORDER BY `col1` ASC RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)", ) def test_apply_window_if_present_range_bounded_timedelta(self): @@ -142,7 +142,7 @@ def test_apply_window_if_present_range_bounded_timedelta(self): ) self.assertEqual( result.sql(dialect="bigquery"), - "value OVER (ORDER BY `col1` ASC NULLS LAST RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)", + "value OVER (ORDER BY `col1` ASC RANGE BETWEEN 86400000000 PRECEDING AND 43200000000 FOLLOWING)", ) def test_apply_window_if_present_all_params(self): diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql index e8fabd1129..0dca6d9d49 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_groupby_rolling/out.sql @@ -22,10 +22,13 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER ( - PARTITION BY `bfcol_9` - ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST - ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + WHEN COALESCE( + SUM(CAST(NOT `bfcol_7` IS NULL AS INT64)) OVER ( + PARTITION BY `bfcol_9` + ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ), + 0 ) < 3 THEN NULL ELSE COALESCE( @@ -42,10 +45,13 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER ( - PARTITION BY `bfcol_9` - ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST - ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + WHEN COALESCE( + SUM(CAST(NOT `bfcol_8` IS NULL AS INT64)) OVER ( + PARTITION BY `bfcol_9` + ORDER BY `bfcol_9` ASC NULLS LAST, `rowindex` ASC NULLS LAST + ROWS BETWEEN 3 PRECEDING AND CURRENT ROW + ), + 0 ) < 3 THEN NULL ELSE COALESCE( diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql index 581c81c6b4..fe4cea08cb 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_range_rolling/out.sql @@ -6,14 +6,17 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER ( - ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST - RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW + WHEN COALESCE( + SUM(CAST(NOT `bfcol_1` IS NULL AS INT64)) OVER ( + ORDER BY UNIX_MICROS(`bfcol_0`) ASC + RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW + ), + 0 ) < 1 THEN NULL ELSE COALESCE( SUM(`bfcol_1`) OVER ( - ORDER BY UNIX_MICROS(`bfcol_0`) ASC NULLS LAST + ORDER BY UNIX_MICROS(`bfcol_0`) ASC RANGE BETWEEN 2999999 PRECEDING AND CURRENT ROW ), 0 diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql index 788eb49ddf..bf1e76c55c 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql @@ -7,7 +7,10 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN SUM(CAST(NOT `int64_col` IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW) < 3 + WHEN COALESCE( + SUM(CAST(NOT `int64_col` IS NULL AS INT64)) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW), + 0 + ) < 3 THEN NULL ELSE COALESCE( SUM(`int64_col`) OVER (ORDER BY `rowindex` ASC NULLS LAST ROWS BETWEEN 2 PRECEDING AND CURRENT ROW),