Skip to content

Commit 9164faa

Browse files
authored
test: improve unit test coverage for the window compiler (#1997)
1 parent 59c52a5 commit 9164faa

File tree

6 files changed, with 38 additions and 24 deletions.

6 files changed

+38
-24
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,15 @@ def compile(
3737
return UNARY_OP_REGISTRATION[op](op, column, window=window)
3838

3939

40+
@UNARY_OP_REGISTRATION.register(agg_ops.CountOp)
41+
def _(
42+
op: agg_ops.CountOp,
43+
column: typed_expr.TypedExpr,
44+
window: typing.Optional[window_spec.WindowSpec] = None,
45+
) -> sge.Expression:
46+
return apply_window_if_present(sge.func("COUNT", column.expr), window)
47+
48+
4049
@UNARY_OP_REGISTRATION.register(agg_ops.SumOp)
4150
def _(
4251
op: agg_ops.SumOp,

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,9 @@ def compile_window(
336336
this=is_observation_expr, expression=expr
337337
)
338338
is_observation = ir._cast(is_observation_expr, "INT64")
339+
observation_count = windows.apply_window_if_present(
340+
sge.func("SUM", is_observation), window_spec
341+
)
339342
else:
340343
# Operations like count treat even NULLs as valid observations
341344
# for the sake of min_periods; notnull is just used to convert
@@ -344,10 +347,10 @@ def compile_window(
344347
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
345348
"INT64",
346349
)
350+
observation_count = windows.apply_window_if_present(
351+
sge.func("COUNT", is_observation), window_spec
352+
)
347353

348-
observation_count = windows.apply_window_if_present(
349-
sge.func("SUM", is_observation), window_spec
350-
)
351354
clauses.append(
352355
(
353356
observation_count < sge.convert(window_spec.min_periods),

tests/system/small/engines/test_windowing.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,21 +35,23 @@ def test_engines_with_offsets(
3535
assert_equivalence_execution(result.node, REFERENCE_ENGINE, engine)
3636

3737

38+
@pytest.mark.parametrize("never_skip_nulls", [True, False])
39+
@pytest.mark.parametrize("agg_op", [agg_ops.sum_op, agg_ops.count_op])
3840
def test_engines_with_rows_window(
3941
scalars_array_value: array_value.ArrayValue,
4042
bigquery_client: bigquery.Client,
43+
never_skip_nulls,
44+
agg_op,
4145
):
4246
window = window_spec.WindowSpec(
4347
bounds=window_spec.RowsWindowBounds.from_window_size(3, "left"),
4448
)
4549
window_node = nodes.WindowOpNode(
4650
child=scalars_array_value.node,
47-
expression=expression.UnaryAggregation(
48-
agg_ops.sum_op, expression.deref("int64_too")
49-
),
51+
expression=expression.UnaryAggregation(agg_op, expression.deref("int64_too")),
5052
window_spec=window,
51-
output_name=identifiers.ColumnId("sum_int64"),
52-
never_skip_nulls=False,
53+
output_name=identifiers.ColumnId("agg_int64"),
54+
never_skip_nulls=never_skip_nulls,
5355
skip_reproject_unsafe=False,
5456
)
5557

tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_w_min_periods/out.sql renamed to tests/unit/core/compile/sqlglot/snapshots/test_compile_window/test_compile_window_wo_skips_nulls_op/out.sql

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,14 @@ WITH `bfcte_0` AS (
77
SELECT
88
*,
99
CASE
10-
WHEN SUM(CAST(NOT `bfcol_0` IS NULL AS INT64)) OVER (
10+
WHEN COUNT(CAST(NOT `bfcol_0` IS NULL AS INT64)) OVER (
1111
ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST
12-
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
13-
) < 3
12+
ROWS BETWEEN 4 PRECEDING AND CURRENT ROW
13+
) < 5
1414
THEN NULL
15-
ELSE COALESCE(
16-
SUM(`bfcol_0`) OVER (
17-
ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST
18-
ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
19-
),
20-
0
15+
ELSE COUNT(`bfcol_0`) OVER (
16+
ORDER BY `bfcol_1` IS NULL ASC NULLS LAST, `bfcol_1` ASC NULLS LAST
17+
ROWS BETWEEN 4 PRECEDING AND CURRENT ROW
2118
)
2219
END AS `bfcol_4`
2320
FROM `bfcte_0`

tests/unit/core/compile/sqlglot/test_compile_window.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,20 @@
3030
)
3131

3232

33-
def test_compile_window_w_rolling(scalar_types_df: bpd.DataFrame, snapshot):
33+
def test_compile_window_w_skips_nulls_op(scalar_types_df: bpd.DataFrame, snapshot):
3434
bf_df = scalar_types_df[["int64_col"]].sort_index()
35+
# The SumOp's skips_nulls is True
3536
result = bf_df.rolling(window=3).sum()
3637
snapshot.assert_match(result.sql, "out.sql")
3738

3839

40+
def test_compile_window_wo_skips_nulls_op(scalar_types_df: bpd.DataFrame, snapshot):
41+
bf_df = scalar_types_df[["int64_col"]].sort_index()
42+
# The CountOp's skips_nulls is False
43+
result = bf_df.rolling(window=5).count()
44+
snapshot.assert_match(result.sql, "out.sql")
45+
46+
3947
def test_compile_window_w_groupby_rolling(scalar_types_df: bpd.DataFrame, snapshot):
4048
bf_df = scalar_types_df[["bool_col", "int64_col"]].sort_index()
4149
result = (
@@ -46,13 +54,8 @@ def test_compile_window_w_groupby_rolling(scalar_types_df: bpd.DataFrame, snapsh
4654
snapshot.assert_match(result.sql, "out.sql")
4755

4856

49-
def test_compile_window_w_min_periods(scalar_types_df: bpd.DataFrame, snapshot):
50-
bf_df = scalar_types_df[["int64_col"]].sort_index()
51-
result = bf_df.expanding(min_periods=3).sum()
52-
snapshot.assert_match(result.sql, "out.sql")
53-
54-
5557
def test_compile_window_w_range_rolling(compiler_session, snapshot):
58+
# TODO: use `duration_col` instead.
5659
values = np.arange(20)
5760
pd_df = pd.DataFrame(
5861
{

0 commit comments

Comments
 (0)