From 9620195799dca5286c87d830f6466a1ba986f7ce Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Fri, 15 Aug 2025 00:18:37 +0000 Subject: [PATCH] test: improve unit test coverage for the windows compiler --- .../sqlglot/aggregations/unary_compiler.py | 9 +++++++++ bigframes/core/compile/sqlglot/compiler.py | 9 ++++++--- tests/system/small/engines/test_windowing.py | 12 +++++++----- .../out.sql | 0 .../out.sql | 15 ++++++--------- .../core/compile/sqlglot/test_compile_window.py | 17 ++++++++++------- 6 files changed, 38 insertions(+), 24 deletions(-) rename tests/unit/core/compile/sqlglot/snapshots/test_compile_window/{test_compile_window_w_rolling => test_compile_window_w_skips_nulls_op}/out.sql (100%) rename tests/unit/core/compile/sqlglot/snapshots/test_compile_window/{test_compile_window_w_min_periods => test_compile_window_wo_skips_nulls_op}/out.sql (56%) diff --git a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py index c65c971bfa..c7eb84cba6 100644 --- a/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/aggregations/unary_compiler.py @@ -37,6 +37,15 @@ def compile( return UNARY_OP_REGISTRATION[op](op, column, window=window) +@UNARY_OP_REGISTRATION.register(agg_ops.CountOp) +def _( + op: agg_ops.CountOp, + column: typed_expr.TypedExpr, + window: typing.Optional[window_spec.WindowSpec] = None, +) -> sge.Expression: + return apply_window_if_present(sge.func("COUNT", column.expr), window) + + @UNARY_OP_REGISTRATION.register(agg_ops.SumOp) def _( op: agg_ops.SumOp, diff --git a/bigframes/core/compile/sqlglot/compiler.py b/bigframes/core/compile/sqlglot/compiler.py index 1a8455176a..b4dc6174be 100644 --- a/bigframes/core/compile/sqlglot/compiler.py +++ b/bigframes/core/compile/sqlglot/compiler.py @@ -336,6 +336,9 @@ def compile_window( this=is_observation_expr, expression=expr ) is_observation = ir._cast(is_observation_expr, "INT64") + observation_count = windows.apply_window_if_present( + sge.func("SUM", is_observation), window_spec + ) else: # Operations like count treat even NULLs as valid observations # for the sake of min_periods notnull is just used to convert @@ -344,10 +347,10 @@ def compile_window( sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())), "INT64", ) + observation_count = windows.apply_window_if_present( + sge.func("COUNT", is_observation), window_spec + ) - observation_count = windows.apply_window_if_present( - sge.func("SUM", is_observation), window_spec - ) clauses.append( ( observation_count < sge.convert(window_spec.min_periods), diff --git a/tests/system/small/engines/test_windowing.py b/tests/system/small/engines/test_windowing.py index 3712e4c047..a5f20a47cd 100644 --- a/tests/system/small/engines/test_windowing.py +++ b/tests/system/small/engines/test_windowing.py @@ -35,21 +35,23 @@ def test_engines_with_offsets( assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine) +@pytest.mark.parametrize("never_skip_nulls", [True, False]) +@pytest.mark.parametrize("agg_op", [agg_ops.sum_op, agg_ops.count_op]) def test_engines_with_rows_window( scalars_array_value: array_value.ArrayValue, bigquery_client: bigquery.Client, + never_skip_nulls, + agg_op, ): window = window_spec.WindowSpec( bounds=window_spec.RowsWindowBounds.from_window_size(3, "left"), ) window_node = nodes.WindowOpNode( child=scalars_array_value.node, - expression=expression.UnaryAggregation( - agg_ops.sum_op, expression.deref("int64_too") - ), + expression=expression.UnaryAggregation(agg_op, expression.deref("int64_too")), window_spec=window, - output_name=identifiers.ColumnId("sum_int64"), - never_skip_nulls=False, + output_name=identifiers.ColumnId("agg_int64"), + never_skip_nulls=never_skip_nulls, skip_reproject_unsafe=False, ) diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_rolling/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_rolling/out.sql rename to tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_skips_nulls_op/out.sql diff --git a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql similarity index 56% rename from tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql rename to tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql index 5885f5ab3c..1d5d9a9e45 100644 --- a/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql +++ b/tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql @@ -7,17 +7,14 @@ WITH `bfcte_0` AS ( SELECT *, CASE - WHEN SUM(CAST(NOT `bfcol_0` IS NULL AS INT64)) OVER ( + WHEN COUNT(CAST(NOT `bfcol_0` IS NULL AS INT64)) OVER ( ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST - ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - ) < 3 + ROWS BETWEEN 4 PRECEDING AND CURRENT ROW + ) < 5 THEN NULL - ELSE COALESCE( - SUM(`bfcol_0`) OVER ( - ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST - ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW - ), - 0 + ELSE COUNT(`bfcol_0`) OVER ( + ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST + ROWS BETWEEN 4 PRECEDING AND CURRENT ROW ) END AS `bfcol_4` FROM `bfcte_0` diff --git a/tests/unit/core/compile/sqlglot/test_compile_window.py b/tests/unit/core/compile/sqlglot/test_compile_window.py index 5a6e3e5322..1fc70dc30f 100644 --- a/tests/unit/core/compile/sqlglot/test_compile_window.py +++ b/tests/unit/core/compile/sqlglot/test_compile_window.py @@ -30,12 +30,20 @@ ) -def test_compile_window_w_rolling(scalar_types_df: bpd.DataFrame, snapshot): +def test_compile_window_w_skips_nulls_op(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col"]].sort_index() + # The SumOp's skips_nulls is True result = bf_df.rolling(window=3).sum() snapshot.assert_match(result.sql, "out.sql") +def test_compile_window_wo_skips_nulls_op(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col"]].sort_index() + # The CountOp's skips_nulls is False + result = bf_df.rolling(window=5).count() + snapshot.assert_match(result.sql, "out.sql") + + def test_compile_window_w_groupby_rolling(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["bool_col", "int64_col"]].sort_index() result = ( @@ -46,13 +54,8 @@ def test_compile_window_w_groupby_rolling(scalar_types_df: bpd.DataFrame, snapsh snapshot.assert_match(result.sql, "out.sql") -def test_compile_window_w_min_periods(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col"]].sort_index() - result = bf_df.expanding(min_periods=3).sum() - snapshot.assert_match(result.sql, "out.sql") - - def test_compile_window_w_range_rolling(compiler_session, snapshot): + # TODO: use `duration_col` instead. values = np.arange(20) pd_df = pd.DataFrame( {